Diffstat (limited to 'lib/Transforms')
252 files changed, 17948 insertions, 9441 deletions
diff --git a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index c795866ec0f2..06222d7e7e44 100644 --- a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -1,9 +1,8 @@ //===- AggressiveInstCombine.cpp ------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h index f3c8bde9f8ff..44e1c45664e7 100644 --- a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h +++ b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h @@ -1,9 +1,8 @@ //===- AggressiveInstCombineInternal.h --------------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp index 8289b2d68f8a..7c5767912fd3 100644 --- a/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp +++ b/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -1,9 +1,8 @@ //===- TruncInstCombine.cpp -----------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Coroutines/CoroCleanup.cpp b/lib/Transforms/Coroutines/CoroCleanup.cpp index 359876627fce..1fb0a114d0c7 100644 --- a/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -1,9 +1,8 @@ //===- CoroCleanup.cpp - Coroutine Cleanup Pass ---------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // This pass lowers all remaining coroutine intrinsics. 
@@ -50,7 +49,7 @@ static void lowerSubFn(IRBuilder<> &Builder, CoroSubFnInst *SubFn) { Builder.SetInsertPoint(SubFn); auto *FramePtr = Builder.CreateBitCast(FrameRaw, FramePtrTy); auto *Gep = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, Index); - auto *Load = Builder.CreateLoad(Gep); + auto *Load = Builder.CreateLoad(FrameTy->getElementType(Index), Gep); SubFn->replaceAllUsesWith(Load); } diff --git a/lib/Transforms/Coroutines/CoroEarly.cpp b/lib/Transforms/Coroutines/CoroEarly.cpp index ac47a06281a5..692697d6f32e 100644 --- a/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/lib/Transforms/Coroutines/CoroEarly.cpp @@ -1,9 +1,8 @@ //===- CoroEarly.cpp - Coroutine Early Function Pass ----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // This pass lowers coroutine intrinsics that hide the details of the exact @@ -98,7 +97,7 @@ void Lowerer::lowerCoroDone(IntrinsicInst *II) { Builder.SetInsertPoint(II); auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy); auto *Gep = Builder.CreateConstInBoundsGEP1_32(FrameTy, BCI, 0); - auto *Load = Builder.CreateLoad(Gep); + auto *Load = Builder.CreateLoad(FrameTy, Gep); auto *Cond = Builder.CreateICmpEQ(Load, NullPtr); II->replaceAllUsesWith(Cond); @@ -114,7 +113,7 @@ void Lowerer::lowerCoroNoop(IntrinsicInst *II) { StructType *FrameTy = StructType::create(C, "NoopCoro.Frame"); auto *FramePtrTy = FrameTy->getPointerTo(); auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy, - /*IsVarArgs=*/false); + /*isVarArg=*/false); auto *FnPtrTy = FnTy->getPointerTo(); FrameTy->setBody({FnPtrTy, FnPtrTy}); diff --git a/lib/Transforms/Coroutines/CoroElide.cpp b/lib/Transforms/Coroutines/CoroElide.cpp index 58f952b54f3a..6707aa1c827d 100644 --- a/lib/Transforms/Coroutines/CoroElide.cpp +++ b/lib/Transforms/Coroutines/CoroElide.cpp @@ -1,9 +1,8 @@ //===- CoroElide.cpp - Coroutine Frame Allocation Elision Pass ------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // This pass replaces dynamic allocation of coroutine frame with alloca and diff --git a/lib/Transforms/Coroutines/CoroFrame.cpp b/lib/Transforms/Coroutines/CoroFrame.cpp index 4cb0a52961cc..58bf22bee29b 100644 --- a/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/lib/Transforms/Coroutines/CoroFrame.cpp @@ -1,9 +1,8 @@ //===- CoroFrame.cpp - Builds and manipulates coroutine frame -------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // This file contains classes used to discover if for a particular value @@ -53,7 +52,7 @@ public: } size_t blockToIndex(BasicBlock *BB) const { - auto *I = std::lower_bound(V.begin(), V.end(), BB); + auto *I = llvm::lower_bound(V, BB); assert(I != V.end() && *I == BB && "BasicBlockNumberng: Unknown block"); return I - V.begin(); } @@ -379,7 +378,7 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, StructType *FrameTy = StructType::create(C, Name); auto *FramePtrTy = FrameTy->getPointerTo(); auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy, - /*IsVarArgs=*/false); + /*isVarArg=*/false); auto *FnPtrTy = FnTy->getPointerTo(); // Figure out how wide should be an integer type storing the suspend index. @@ -403,6 +402,7 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, if (CurrentDef == Shape.PromiseAlloca) continue; + uint64_t Count = 1; Type *Ty = nullptr; if (auto *AI = dyn_cast<AllocaInst>(CurrentDef)) { Ty = AI->getAllocatedType(); @@ -414,11 +414,18 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, Padder.addType(PaddingTy); } } + if (auto *CI = dyn_cast<ConstantInt>(AI->getArraySize())) + Count = CI->getValue().getZExtValue(); + else + report_fatal_error("Coroutines cannot handle non static allocas yet"); } else { Ty = CurrentDef->getType(); } S.setFieldIndex(Types.size()); - Types.push_back(Ty); + if (Count == 1) + Types.push_back(Ty); + else + Types.push_back(ArrayType::get(Ty, Count)); Padder.addType(Ty); } FrameTy->setBody(Types); @@ -471,11 +478,12 @@ static Instruction *splitBeforeCatchSwitch(CatchSwitchInst *CatchSwitch) { // static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { auto *CB = Shape.CoroBegin; + LLVMContext &C = CB->getContext(); IRBuilder<> Builder(CB->getNextNode()); - PointerType *FramePtrTy = Shape.FrameTy->getPointerTo(); + StructType *FrameTy = Shape.FrameTy; + PointerType *FramePtrTy = FrameTy->getPointerTo(); auto *FramePtr = cast<Instruction>(Builder.CreateBitCast(CB, FramePtrTy, "FramePtr")); - Type *FrameTy = FramePtrTy->getElementType(); Value *CurrentValue = nullptr; BasicBlock *CurrentBlock = nullptr; @@ -492,17 +500,41 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { if (Shape.PromiseAlloca) Allocas.emplace_back(Shape.PromiseAlloca, coro::Shape::PromiseField); + // Create a GEP with the given index into the coroutine frame for the original + // value Orig. Appends an extra 0 index for array-allocas, preserving the + // original type. + auto GetFramePointer = [&](uint32_t Index, Value *Orig) -> Value * { + SmallVector<Value *, 3> Indices = { + ConstantInt::get(Type::getInt32Ty(C), 0), + ConstantInt::get(Type::getInt32Ty(C), Index), + }; + + if (auto *AI = dyn_cast<AllocaInst>(Orig)) { + if (auto *CI = dyn_cast<ConstantInt>(AI->getArraySize())) { + auto Count = CI->getValue().getZExtValue(); + if (Count > 1) { + Indices.push_back(ConstantInt::get(Type::getInt32Ty(C), 0)); + } + } else { + report_fatal_error("Coroutines cannot handle non static allocas yet"); + } + } + + return Builder.CreateInBoundsGEP(FrameTy, FramePtr, Indices); + }; + // Create a load instruction to reload the spilled value from the coroutine // frame. 
auto CreateReload = [&](Instruction *InsertBefore) { assert(Index && "accessing unassigned field number"); Builder.SetInsertPoint(InsertBefore); - auto *G = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, Index, - CurrentValue->getName() + - Twine(".reload.addr")); + + auto *G = GetFramePointer(Index, CurrentValue); + G->setName(CurrentValue->getName() + Twine(".reload.addr")); + return isa<AllocaInst>(CurrentValue) ? G - : Builder.CreateLoad(G, + : Builder.CreateLoad(FrameTy->getElementType(Index), G, CurrentValue->getName() + Twine(".reload")); }; @@ -589,8 +621,8 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { Builder.SetInsertPoint(&Shape.AllocaSpillBlock->front()); // If we found any allocas, replace all of their remaining uses with Geps. for (auto &P : Allocas) { - auto *G = - Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, P.second); + auto *G = GetFramePointer(P.second, P.first); + // We are not using ReplaceInstWithInst(P.first, cast<Instruction>(G)) here, // as we are changing location of the instruction. G->takeName(P.first); diff --git a/lib/Transforms/Coroutines/CoroInstr.h b/lib/Transforms/Coroutines/CoroInstr.h index 9a8cc5a2591c..5e19d7642e38 100644 --- a/lib/Transforms/Coroutines/CoroInstr.h +++ b/lib/Transforms/Coroutines/CoroInstr.h @@ -1,9 +1,8 @@ //===-- CoroInstr.h - Coroutine Intrinsics Instruction Wrappers -*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // This file defines classes that make it really easy to deal with intrinsic diff --git a/lib/Transforms/Coroutines/CoroInternal.h b/lib/Transforms/Coroutines/CoroInternal.h index 8e690d649cf5..441c8a20f1f3 100644 --- a/lib/Transforms/Coroutines/CoroInternal.h +++ b/lib/Transforms/Coroutines/CoroInternal.h @@ -1,9 +1,8 @@ //===- CoroInternal.h - Internal Coroutine interfaces ---------*- C++ -*---===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // Common definitions/declarations used internally by coroutine lowering passes. diff --git a/lib/Transforms/Coroutines/CoroSplit.cpp b/lib/Transforms/Coroutines/CoroSplit.cpp index 9eeceb217ba8..5458e70ff16a 100644 --- a/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/lib/Transforms/Coroutines/CoroSplit.cpp @@ -1,9 +1,8 @@ //===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // This pass builds the coroutine frame and outlines resume and destroy parts @@ -94,7 +93,7 @@ static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) { auto *FrameTy = Shape.FrameTy; auto *GepIndex = Builder.CreateConstInBoundsGEP2_32( FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr"); - auto *Index = Builder.CreateLoad(GepIndex, "index"); + auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index"); auto *Switch = Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size()); Shape.ResumeSwitch = Switch; @@ -230,7 +229,8 @@ static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr, Builder.SetInsertPoint(OldSwitchBB->getTerminator()); auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr, 0, 0, "ResumeFn.addr"); - auto *Load = Builder.CreateLoad(GepIndex); + auto *Load = Builder.CreateLoad( + Shape.FrameTy->getElementType(coro::Shape::ResumeField), GepIndex); auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(Load->getType())); auto *Cond = Builder.CreateICmpEQ(Load, NullPtr); @@ -777,6 +777,8 @@ static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) { } static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) { + EliminateUnreachableBlocks(F); + coro::Shape Shape(F); if (!Shape.CoroBegin) return; @@ -828,6 +830,7 @@ static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) { // split. static void prepareForSplit(Function &F, CallGraph &CG) { Module &M = *F.getParent(); + LLVMContext &Context = F.getContext(); #ifndef NDEBUG Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN); assert(DevirtFn && "coro.devirt.trigger function not found"); @@ -842,10 +845,12 @@ static void prepareForSplit(Function &F, CallGraph &CG) { // call void %1(i8* null) coro::LowererBase Lowerer(M); Instruction *InsertPt = F.getEntryBlock().getTerminator(); - auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(F.getContext())); + auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context)); auto *DevirtFnAddr = Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt); - auto *IndirectCall = CallInst::Create(DevirtFnAddr, Null, "", InsertPt); + FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Context), + {Type::getInt8PtrTy(Context)}, false); + auto *IndirectCall = CallInst::Create(FnTy, DevirtFnAddr, Null, "", InsertPt); // Update CG graph with an indirect call we just added. 
CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode()); @@ -861,7 +866,7 @@ static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) { LLVMContext &C = M.getContext(); auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C), - /*IsVarArgs=*/false); + /*isVarArg=*/false); Function *DevirtFn = Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage, CORO_DEVIRT_TRIGGER_FN, &M); @@ -941,7 +946,12 @@ struct CoroSplit : public CallGraphSCCPass { char CoroSplit::ID = 0; -INITIALIZE_PASS( +INITIALIZE_PASS_BEGIN( + CoroSplit, "coro-split", + "Split coroutine into a set of functions driving its state machine", false, + false) +INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_END( CoroSplit, "coro-split", "Split coroutine into a set of functions driving its state machine", false, false) diff --git a/lib/Transforms/Coroutines/Coroutines.cpp b/lib/Transforms/Coroutines/Coroutines.cpp index cf84f916e24b..a581d1d21169 100644 --- a/lib/Transforms/Coroutines/Coroutines.cpp +++ b/lib/Transforms/Coroutines/Coroutines.cpp @@ -1,9 +1,8 @@ //===- Coroutines.cpp -----------------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -177,15 +176,15 @@ static void buildCGN(CallGraph &CG, CallGraphNode *Node) { // Look for calls by this function. for (Instruction &I : instructions(F)) - if (CallSite CS = CallSite(cast<Value>(&I))) { - const Function *Callee = CS.getCalledFunction(); + if (auto *Call = dyn_cast<CallBase>(&I)) { + const Function *Callee = Call->getCalledFunction(); if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID())) // Indirect calls of intrinsics are not allowed so no need to check. // We can be more precise here by using TargetArg returned by // Intrinsic::isLeaf. - Node->addCalledFunction(CS, CG.getCallsExternalNode()); + Node->addCalledFunction(Call, CG.getCallsExternalNode()); else if (!Callee->isIntrinsic()) - Node->addCalledFunction(CS, CG.getOrInsertFunction(Callee)); + Node->addCalledFunction(Call, CG.getOrInsertFunction(Callee)); } } diff --git a/lib/Transforms/IPO/AlwaysInliner.cpp b/lib/Transforms/IPO/AlwaysInliner.cpp index 07138718ce2c..c50805692b98 100644 --- a/lib/Transforms/IPO/AlwaysInliner.cpp +++ b/lib/Transforms/IPO/AlwaysInliner.cpp @@ -1,9 +1,8 @@ //===- InlineAlways.cpp - Code to inline always_inline functions ----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -32,8 +31,17 @@ using namespace llvm; #define DEBUG_TYPE "inline" -PreservedAnalyses AlwaysInlinerPass::run(Module &M, ModuleAnalysisManager &) { - InlineFunctionInfo IFI; +PreservedAnalyses AlwaysInlinerPass::run(Module &M, + ModuleAnalysisManager &MAM) { + // Add inline assumptions during code generation. + FunctionAnalysisManager &FAM = + MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + std::function<AssumptionCache &(Function &)> GetAssumptionCache = + [&](Function &F) -> AssumptionCache & { + return FAM.getResult<AssumptionAnalysis>(F); + }; + InlineFunctionInfo IFI(/*cg=*/nullptr, &GetAssumptionCache); + SmallSetVector<CallSite, 16> Calls; bool Changed = false; SmallVector<Function *, 16> InlinedFunctions; @@ -146,11 +154,20 @@ InlineCost AlwaysInlinerLegacyPass::getInlineCost(CallSite CS) { Function *Callee = CS.getCalledFunction(); // Only inline direct calls to functions with always-inline attributes - // that are viable for inlining. FIXME: We shouldn't even get here for - // declarations. - if (Callee && !Callee->isDeclaration() && - CS.hasFnAttr(Attribute::AlwaysInline) && isInlineViable(*Callee)) - return InlineCost::getAlways("always inliner"); + // that are viable for inlining. + if (!Callee) + return InlineCost::getNever("indirect call"); + + // FIXME: We shouldn't even get here for declarations. + if (Callee->isDeclaration()) + return InlineCost::getNever("no definition"); + + if (!CS.hasFnAttr(Attribute::AlwaysInline)) + return InlineCost::getNever("no alwaysinline attribute"); + + auto IsViable = isInlineViable(*Callee); + if (!IsViable) + return InlineCost::getNever(IsViable.message); - return InlineCost::getNever("always inliner"); + return InlineCost::getAlways("always inliner"); } diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp index 4663de0b049e..95a9f31cced3 100644 --- a/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -1,9 +1,8 @@ //===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -59,11 +58,13 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/NoFolder.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" @@ -243,6 +244,7 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, assert(CS.getCalledFunction() == F); Instruction *Call = CS.getInstruction(); const AttributeList &CallPAL = CS.getAttributes(); + IRBuilder<NoFolder> IRB(Call); // Loop over the operands, inserting GEP and loads in the caller as // appropriate. 
@@ -261,10 +263,11 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr}; for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); - Value *Idx = GetElementPtrInst::Create( - STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i), Call); + auto *Idx = + IRB.CreateGEP(STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i)); // TODO: Tell AA about the new values? - Args.push_back(new LoadInst(Idx, Idx->getName() + ".val", Call)); + Args.push_back(IRB.CreateLoad(STy->getElementType(i), Idx, + Idx->getName() + ".val")); ArgAttrVec.push_back(AttributeSet()); } } else if (!I->use_empty()) { @@ -294,13 +297,13 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II); } // And create a GEP to extract those indices. - V = GetElementPtrInst::Create(ArgIndex.first, V, Ops, - V->getName() + ".idx", Call); + V = IRB.CreateGEP(ArgIndex.first, V, Ops, V->getName() + ".idx"); Ops.clear(); } // Since we're replacing a load make sure we take the alignment // of the previous load. - LoadInst *newLoad = new LoadInst(V, V->getName() + ".val", Call); + LoadInst *newLoad = + IRB.CreateLoad(OrigLoad->getType(), V, V->getName() + ".val"); newLoad->setAlignment(OrigLoad->getAlignment()); // Transfer the AA info too. AAMDNodes AAInfo; @@ -476,9 +479,9 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, return NF; } -/// AllCallersPassInValidPointerForArgument - Return true if we can prove that -/// all callees pass in a valid pointer for the specified function argument. -static bool allCallersPassInValidPointerForArgument(Argument *Arg) { +/// Return true if we can prove that all callees pass in a valid pointer for the +/// specified function argument. +static bool allCallersPassValidPointerForArgument(Argument *Arg, Type *Ty) { Function *Callee = Arg->getParent(); const DataLayout &DL = Callee->getParent()->getDataLayout(); @@ -490,7 +493,7 @@ static bool allCallersPassInValidPointerForArgument(Argument *Arg) { CallSite CS(U); assert(CS && "Should only have direct calls!"); - if (!isDereferenceablePointer(CS.getArgument(ArgNo), DL)) + if (!isDereferenceablePointer(CS.getArgument(ArgNo), Ty, DL)) return false; } return true; @@ -563,8 +566,8 @@ static void markIndicesSafe(const IndicesVector &ToMark, /// This method limits promotion of aggregates to only promote up to three /// elements of the aggregate in order to avoid exploding the number of /// arguments passed in. -static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, - AAResults &AAR, unsigned MaxElements) { +static bool isSafeToPromoteArgument(Argument *Arg, Type *ByValTy, AAResults &AAR, + unsigned MaxElements) { using GEPIndicesSet = std::set<IndicesVector>; // Quick exit for unused arguments @@ -586,9 +589,6 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, // // This set will contain all sets of indices that are loaded in the entry // block, and thus are safe to unconditionally load in the caller. - // - // This optimization is also safe for InAlloca parameters, because it verifies - // that the address isn't captured. GEPIndicesSet SafeToUnconditionallyLoad; // This set contains all the sets of indices that we are planning to promote. 
@@ -596,9 +596,28 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, GEPIndicesSet ToPromote; // If the pointer is always valid, any load with first index 0 is valid. - if (isByValOrInAlloca || allCallersPassInValidPointerForArgument(Arg)) + + if (ByValTy) SafeToUnconditionallyLoad.insert(IndicesVector(1, 0)); + // Whenever a new underlying type for the operand is found, make sure it's + // consistent with the GEPs and loads we've already seen and, if necessary, + // use it to see if all incoming pointers are valid (which implies the 0-index + // is safe). + Type *BaseTy = ByValTy; + auto UpdateBaseTy = [&](Type *NewBaseTy) { + if (BaseTy) + return BaseTy == NewBaseTy; + + BaseTy = NewBaseTy; + if (allCallersPassValidPointerForArgument(Arg, BaseTy)) { + assert(SafeToUnconditionallyLoad.empty()); + SafeToUnconditionallyLoad.insert(IndicesVector(1, 0)); + } + + return true; + }; + // First, iterate the entry block and mark loads of (geps of) arguments as // safe. BasicBlock &EntryBlock = Arg->getParent()->front(); @@ -621,6 +640,9 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, // right away, can't promote this argument at all. return false; + if (!UpdateBaseTy(GEP->getSourceElementType())) + return false; + // Indices checked out, mark them as safe markIndicesSafe(Indices, SafeToUnconditionallyLoad); Indices.clear(); @@ -628,6 +650,11 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, } else if (V == Arg) { // Direct loads are equivalent to a GEP with a single 0 index. markIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad); + + if (BaseTy && LI->getType() != BaseTy) + return false; + + BaseTy = LI->getType(); } } @@ -645,6 +672,9 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, Loads.push_back(LI); // Direct loads are equivalent to a GEP with a zero index and then a load. Operands.push_back(0); + + if (!UpdateBaseTy(LI->getType())) + return false; } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(UR)) { if (GEP->use_empty()) { // Dead GEP's cause trouble later. Just remove them if we run into @@ -653,10 +683,12 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, // TODO: This runs the above loop over and over again for dead GEPs // Couldn't we just do increment the UI iterator earlier and erase the // use? - return isSafeToPromoteArgument(Arg, isByValOrInAlloca, AAR, - MaxElements); + return isSafeToPromoteArgument(Arg, ByValTy, AAR, MaxElements); } + if (!UpdateBaseTy(GEP->getSourceElementType())) + return false; + // Ensure that all of the indices are constants. for (User::op_iterator i = GEP->idx_begin(), e = GEP->idx_end(); i != e; ++i) @@ -853,6 +885,11 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, if (F->isVarArg()) return nullptr; + // Don't transform functions that receive inallocas, as the transformation may + // not be safe depending on calling convention. + if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca)) + return nullptr; + // First check: see if there are any pointer arguments! If not, quick exit. 
   SmallVector<Argument *, 16> PointerArgs;
   for (Argument &I : F->args())
@@ -911,8 +948,7 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter,
     // If this is a byval argument, and if the aggregate type is small, just
     // pass the elements, which is always safe, if the passed value is densely
-    // packed or if we can prove the padding bytes are never accessed. This does
-    // not apply to inalloca.
+    // packed or if we can prove the padding bytes are never accessed.
     bool isSafeToPromote =
         PtrArg->hasByValAttr() &&
         (isDenselyPacked(AgTy, DL) || !canPaddingBeAccessed(PtrArg));
@@ -963,8 +999,9 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter,
     }
     // Otherwise, see if we can promote the pointer to its value.
-    if (isSafeToPromoteArgument(PtrArg, PtrArg->hasByValOrInAllocaAttr(), AAR,
-                                MaxElements))
+    Type *ByValTy =
+        PtrArg->hasByValAttr() ? PtrArg->getParamByValType() : nullptr;
+    if (isSafeToPromoteArgument(PtrArg, ByValTy, AAR, MaxElements))
       ArgsToPromote.insert(PtrArg);
   }
@@ -1101,7 +1138,9 @@ bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
       CallGraphNode *NewCalleeNode =
           CG.getOrInsertFunction(NewCS.getCalledFunction());
       CallGraphNode *CallerNode = CG[Caller];
-      CallerNode->replaceCallEdge(OldCS, NewCS, NewCalleeNode);
+      CallerNode->replaceCallEdge(*cast<CallBase>(OldCS.getInstruction()),
+                                  *cast<CallBase>(NewCS.getInstruction()),
+                                  NewCalleeNode);
     };
     const TargetTransformInfo &TTI =
diff --git a/lib/Transforms/IPO/Attributor.cpp b/lib/Transforms/IPO/Attributor.cpp
new file mode 100644
index 000000000000..2a52c6b9b4ad
--- /dev/null
+++ b/lib/Transforms/IPO/Attributor.cpp
@@ -0,0 +1,1690 @@
+//===- Attributor.cpp - Module-wide attribute deduction -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an interprocedural pass that deduces and/or propagates
+// attributes. This is done in an abstract interpretation style fixpoint
+// iteration. See the Attributor.h file comment and the class descriptions in
+// that file for more information.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/Attributor.h" + +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> + +using namespace llvm; + +#define DEBUG_TYPE "attributor" + +STATISTIC(NumFnWithExactDefinition, + "Number of function with exact definitions"); +STATISTIC(NumFnWithoutExactDefinition, + "Number of function without exact definitions"); +STATISTIC(NumAttributesTimedOut, + "Number of abstract attributes timed out before fixpoint"); +STATISTIC(NumAttributesValidFixpoint, + "Number of abstract attributes in a valid fixpoint state"); +STATISTIC(NumAttributesManifested, + "Number of abstract attributes manifested in IR"); +STATISTIC(NumFnNoUnwind, "Number of functions marked nounwind"); + +STATISTIC(NumFnUniqueReturned, "Number of function with unique return"); +STATISTIC(NumFnKnownReturns, "Number of function with known return values"); +STATISTIC(NumFnArgumentReturned, + "Number of function arguments marked returned"); +STATISTIC(NumFnNoSync, "Number of functions marked nosync"); +STATISTIC(NumFnNoFree, "Number of functions marked nofree"); +STATISTIC(NumFnReturnedNonNull, + "Number of function return values marked nonnull"); +STATISTIC(NumFnArgumentNonNull, "Number of function arguments marked nonnull"); +STATISTIC(NumCSArgumentNonNull, "Number of call site arguments marked nonnull"); +STATISTIC(NumFnWillReturn, "Number of functions marked willreturn"); + +// TODO: Determine a good default value. +// +// In the LLVM-TS and SPEC2006, 32 seems to not induce compile time overheads +// (when run with the first 5 abstract attributes). The results also indicate +// that we never reach 32 iterations but always find a fixpoint sooner. +// +// This will become more evolved once we perform two interleaved fixpoint +// iterations: bottom-up and top-down. +static cl::opt<unsigned> + MaxFixpointIterations("attributor-max-iterations", cl::Hidden, + cl::desc("Maximal number of fixpoint iterations."), + cl::init(32)); + +static cl::opt<bool> DisableAttributor( + "attributor-disable", cl::Hidden, + cl::desc("Disable the attributor inter-procedural deduction pass."), + cl::init(true)); + +static cl::opt<bool> VerifyAttributor( + "attributor-verify", cl::Hidden, + cl::desc("Verify the Attributor deduction and " + "manifestation of attributes -- may issue false-positive errors"), + cl::init(false)); + +/// Logic operators for the change status enum class. +/// +///{ +ChangeStatus llvm::operator|(ChangeStatus l, ChangeStatus r) { + return l == ChangeStatus::CHANGED ? l : r; +} +ChangeStatus llvm::operator&(ChangeStatus l, ChangeStatus r) { + return l == ChangeStatus::UNCHANGED ? l : r; +} +///} + +/// Helper to adjust the statistics. 
+static void bookkeeping(AbstractAttribute::ManifestPosition MP, + const Attribute &Attr) { + if (!AreStatisticsEnabled()) + return; + + if (!Attr.isEnumAttribute()) + return; + switch (Attr.getKindAsEnum()) { + case Attribute::NoUnwind: + NumFnNoUnwind++; + return; + case Attribute::Returned: + NumFnArgumentReturned++; + return; + case Attribute::NoSync: + NumFnNoSync++; + break; + case Attribute::NoFree: + NumFnNoFree++; + break; + case Attribute::NonNull: + switch (MP) { + case AbstractAttribute::MP_RETURNED: + NumFnReturnedNonNull++; + break; + case AbstractAttribute::MP_ARGUMENT: + NumFnArgumentNonNull++; + break; + case AbstractAttribute::MP_CALL_SITE_ARGUMENT: + NumCSArgumentNonNull++; + break; + default: + break; + } + break; + case Attribute::WillReturn: + NumFnWillReturn++; + break; + default: + return; + } +} + +template <typename StateTy> +using followValueCB_t = std::function<bool(Value *, StateTy &State)>; +template <typename StateTy> +using visitValueCB_t = std::function<void(Value *, StateTy &State)>; + +/// Recursively visit all values that might become \p InitV at some point. This +/// will be done by looking through cast instructions, selects, phis, and calls +/// with the "returned" attribute. The callback \p FollowValueCB is asked before +/// a potential origin value is looked at. If no \p FollowValueCB is passed, a +/// default one is used that will make sure we visit every value only once. Once +/// we cannot look through the value any further, the callback \p VisitValueCB +/// is invoked and passed the current value and the \p State. To limit how much +/// effort is invested, we will never visit more than \p MaxValues values. +template <typename StateTy> +static bool genericValueTraversal( + Value *InitV, StateTy &State, visitValueCB_t<StateTy> &VisitValueCB, + followValueCB_t<StateTy> *FollowValueCB = nullptr, int MaxValues = 8) { + + SmallPtrSet<Value *, 16> Visited; + followValueCB_t<bool> DefaultFollowValueCB = [&](Value *Val, bool &) { + return Visited.insert(Val).second; + }; + + if (!FollowValueCB) + FollowValueCB = &DefaultFollowValueCB; + + SmallVector<Value *, 16> Worklist; + Worklist.push_back(InitV); + + int Iteration = 0; + do { + Value *V = Worklist.pop_back_val(); + + // Check if we should process the current value. To prevent endless + // recursion keep a record of the values we followed! + if (!(*FollowValueCB)(V, State)) + continue; + + // Make sure we limit the compile time for complex expressions. + if (Iteration++ >= MaxValues) + return false; + + // Explicitly look through calls with a "returned" attribute if we do + // not have a pointer as stripPointerCasts only works on them. + if (V->getType()->isPointerTy()) { + V = V->stripPointerCasts(); + } else { + CallSite CS(V); + if (CS && CS.getCalledFunction()) { + Value *NewV = nullptr; + for (Argument &Arg : CS.getCalledFunction()->args()) + if (Arg.hasReturnedAttr()) { + NewV = CS.getArgOperand(Arg.getArgNo()); + break; + } + if (NewV) { + Worklist.push_back(NewV); + continue; + } + } + } + + // Look through select instructions, visit both potential values. + if (auto *SI = dyn_cast<SelectInst>(V)) { + Worklist.push_back(SI->getTrueValue()); + Worklist.push_back(SI->getFalseValue()); + continue; + } + + // Look through phi nodes, visit all operands. + if (auto *PHI = dyn_cast<PHINode>(V)) { + Worklist.append(PHI->op_begin(), PHI->op_end()); + continue; + } + + // Once a leaf is reached we inform the user through the callback. 
+ VisitValueCB(V, State); + } while (!Worklist.empty()); + + // All values have been visited. + return true; +} + +/// Helper to identify the correct offset into an attribute list. +static unsigned getAttrIndex(AbstractAttribute::ManifestPosition MP, + unsigned ArgNo = 0) { + switch (MP) { + case AbstractAttribute::MP_ARGUMENT: + case AbstractAttribute::MP_CALL_SITE_ARGUMENT: + return ArgNo + AttributeList::FirstArgIndex; + case AbstractAttribute::MP_FUNCTION: + return AttributeList::FunctionIndex; + case AbstractAttribute::MP_RETURNED: + return AttributeList::ReturnIndex; + } + llvm_unreachable("Unknown manifest position!"); +} + +/// Return true if \p New is equal or worse than \p Old. +static bool isEqualOrWorse(const Attribute &New, const Attribute &Old) { + if (!Old.isIntAttribute()) + return true; + + return Old.getValueAsInt() >= New.getValueAsInt(); +} + +/// Return true if the information provided by \p Attr was added to the +/// attribute list \p Attrs. This is only the case if it was not already present +/// in \p Attrs at the position describe by \p MP and \p ArgNo. +static bool addIfNotExistent(LLVMContext &Ctx, const Attribute &Attr, + AttributeList &Attrs, + AbstractAttribute::ManifestPosition MP, + unsigned ArgNo = 0) { + unsigned AttrIdx = getAttrIndex(MP, ArgNo); + + if (Attr.isEnumAttribute()) { + Attribute::AttrKind Kind = Attr.getKindAsEnum(); + if (Attrs.hasAttribute(AttrIdx, Kind)) + if (isEqualOrWorse(Attr, Attrs.getAttribute(AttrIdx, Kind))) + return false; + Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr); + return true; + } + if (Attr.isStringAttribute()) { + StringRef Kind = Attr.getKindAsString(); + if (Attrs.hasAttribute(AttrIdx, Kind)) + if (isEqualOrWorse(Attr, Attrs.getAttribute(AttrIdx, Kind))) + return false; + Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr); + return true; + } + + llvm_unreachable("Expected enum or string attribute!"); +} + +ChangeStatus AbstractAttribute::update(Attributor &A) { + ChangeStatus HasChanged = ChangeStatus::UNCHANGED; + if (getState().isAtFixpoint()) + return HasChanged; + + LLVM_DEBUG(dbgs() << "[Attributor] Update: " << *this << "\n"); + + HasChanged = updateImpl(A); + + LLVM_DEBUG(dbgs() << "[Attributor] Update " << HasChanged << " " << *this + << "\n"); + + return HasChanged; +} + +ChangeStatus AbstractAttribute::manifest(Attributor &A) { + assert(getState().isValidState() && + "Attempted to manifest an invalid state!"); + assert(getAssociatedValue() && + "Attempted to manifest an attribute without associated value!"); + + ChangeStatus HasChanged = ChangeStatus::UNCHANGED; + SmallVector<Attribute, 4> DeducedAttrs; + getDeducedAttributes(DeducedAttrs); + + Function &ScopeFn = getAnchorScope(); + LLVMContext &Ctx = ScopeFn.getContext(); + ManifestPosition MP = getManifestPosition(); + + AttributeList Attrs; + SmallVector<unsigned, 4> ArgNos; + + // In the following some generic code that will manifest attributes in + // DeducedAttrs if they improve the current IR. Due to the different + // annotation positions we use the underlying AttributeList interface. + // Note that MP_CALL_SITE_ARGUMENT can annotate multiple locations. 
+ + switch (MP) { + case MP_ARGUMENT: + ArgNos.push_back(cast<Argument>(getAssociatedValue())->getArgNo()); + Attrs = ScopeFn.getAttributes(); + break; + case MP_FUNCTION: + case MP_RETURNED: + ArgNos.push_back(0); + Attrs = ScopeFn.getAttributes(); + break; + case MP_CALL_SITE_ARGUMENT: { + CallSite CS(&getAnchoredValue()); + for (unsigned u = 0, e = CS.getNumArgOperands(); u != e; u++) + if (CS.getArgOperand(u) == getAssociatedValue()) + ArgNos.push_back(u); + Attrs = CS.getAttributes(); + } + } + + for (const Attribute &Attr : DeducedAttrs) { + for (unsigned ArgNo : ArgNos) { + if (!addIfNotExistent(Ctx, Attr, Attrs, MP, ArgNo)) + continue; + + HasChanged = ChangeStatus::CHANGED; + bookkeeping(MP, Attr); + } + } + + if (HasChanged == ChangeStatus::UNCHANGED) + return HasChanged; + + switch (MP) { + case MP_ARGUMENT: + case MP_FUNCTION: + case MP_RETURNED: + ScopeFn.setAttributes(Attrs); + break; + case MP_CALL_SITE_ARGUMENT: + CallSite(&getAnchoredValue()).setAttributes(Attrs); + } + + return HasChanged; +} + +Function &AbstractAttribute::getAnchorScope() { + Value &V = getAnchoredValue(); + if (isa<Function>(V)) + return cast<Function>(V); + if (isa<Argument>(V)) + return *cast<Argument>(V).getParent(); + if (isa<Instruction>(V)) + return *cast<Instruction>(V).getFunction(); + llvm_unreachable("No scope for anchored value found!"); +} + +const Function &AbstractAttribute::getAnchorScope() const { + return const_cast<AbstractAttribute *>(this)->getAnchorScope(); +} + +/// -----------------------NoUnwind Function Attribute-------------------------- + +struct AANoUnwindFunction : AANoUnwind, BooleanState { + + AANoUnwindFunction(Function &F, InformationCache &InfoCache) + : AANoUnwind(F, InfoCache) {} + + /// See AbstractAttribute::getState() + /// { + AbstractState &getState() override { return *this; } + const AbstractState &getState() const override { return *this; } + /// } + + /// See AbstractAttribute::getManifestPosition(). + ManifestPosition getManifestPosition() const override { return MP_FUNCTION; } + + const std::string getAsStr() const override { + return getAssumed() ? "nounwind" : "may-unwind"; + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override; + + /// See AANoUnwind::isAssumedNoUnwind(). + bool isAssumedNoUnwind() const override { return getAssumed(); } + + /// See AANoUnwind::isKnownNoUnwind(). + bool isKnownNoUnwind() const override { return getKnown(); } +}; + +ChangeStatus AANoUnwindFunction::updateImpl(Attributor &A) { + Function &F = getAnchorScope(); + + // The map from instruction opcodes to those instructions in the function. + auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F); + auto Opcodes = { + (unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr, + (unsigned)Instruction::Call, (unsigned)Instruction::CleanupRet, + (unsigned)Instruction::CatchSwitch, (unsigned)Instruction::Resume}; + + for (unsigned Opcode : Opcodes) { + for (Instruction *I : OpcodeInstMap[Opcode]) { + if (!I->mayThrow()) + continue; + + auto *NoUnwindAA = A.getAAFor<AANoUnwind>(*this, *I); + + if (!NoUnwindAA || !NoUnwindAA->isAssumedNoUnwind()) { + indicatePessimisticFixpoint(); + return ChangeStatus::CHANGED; + } + } + } + return ChangeStatus::UNCHANGED; +} + +/// --------------------- Function Return Values ------------------------------- + +/// "Attribute" that collects all potential returned values and the return +/// instructions that they arise from. 
+///
+/// If there is a unique returned value R, the manifest method will:
+/// - mark R with the "returned" attribute, if R is an argument.
+class AAReturnedValuesImpl final : public AAReturnedValues, AbstractState {
+
+  /// Mapping of values potentially returned by the associated function to the
+  /// return instructions that might return them.
+  DenseMap<Value *, SmallPtrSet<ReturnInst *, 2>> ReturnedValues;
+
+  /// State flags
+  ///
+  ///{
+  bool IsFixed;
+  bool IsValidState;
+  bool HasOverdefinedReturnedCalls;
+  ///}
+
+  /// Collect values that could become \p V in the set \p Values, each mapped to
+  /// \p ReturnInsts.
+  void collectValuesRecursively(
+      Attributor &A, Value *V, SmallPtrSetImpl<ReturnInst *> &ReturnInsts,
+      DenseMap<Value *, SmallPtrSet<ReturnInst *, 2>> &Values) {
+
+    visitValueCB_t<bool> VisitValueCB = [&](Value *Val, bool &) {
+      assert(!isa<Instruction>(Val) ||
+             &getAnchorScope() == cast<Instruction>(Val)->getFunction());
+      Values[Val].insert(ReturnInsts.begin(), ReturnInsts.end());
+    };
+
+    bool UnusedBool;
+    bool Success = genericValueTraversal(V, UnusedBool, VisitValueCB);
+
+    // If we did abort the above traversal we haven't seen all the values.
+    // Consequently, we cannot know if the information we would derive is
+    // accurate so we give up early.
+    if (!Success)
+      indicatePessimisticFixpoint();
+  }
+
+public:
+  /// See AbstractAttribute::AbstractAttribute(...).
+  AAReturnedValuesImpl(Function &F, InformationCache &InfoCache)
+      : AAReturnedValues(F, InfoCache) {
+    // We do not have an associated argument yet.
+    AssociatedVal = nullptr;
+  }
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    // Reset the state.
+    AssociatedVal = nullptr;
+    IsFixed = false;
+    IsValidState = true;
+    HasOverdefinedReturnedCalls = false;
+    ReturnedValues.clear();
+
+    Function &F = cast<Function>(getAnchoredValue());
+
+    // The map from instruction opcodes to those instructions in the function.
+    auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
+
+    // Look through all arguments, if one is marked as returned we are done.
+    for (Argument &Arg : F.args()) {
+      if (Arg.hasReturnedAttr()) {
+
+        auto &ReturnInstSet = ReturnedValues[&Arg];
+        for (Instruction *RI : OpcodeInstMap[Instruction::Ret])
+          ReturnInstSet.insert(cast<ReturnInst>(RI));
+
+        indicateOptimisticFixpoint();
+        return;
+      }
+    }
+
+    // If no argument was marked as returned we look at all return instructions
+    // and collect potentially returned values.
+    for (Instruction *RI : OpcodeInstMap[Instruction::Ret]) {
+      SmallPtrSet<ReturnInst *, 1> RISet({cast<ReturnInst>(RI)});
+      collectValuesRecursively(A, cast<ReturnInst>(RI)->getReturnValue(), RISet,
+                               ReturnedValues);
+    }
+  }
+
+  /// See AbstractAttribute::manifest(...).
+  ChangeStatus manifest(Attributor &A) override;
+
+  /// See AbstractAttribute::getState(...).
+  AbstractState &getState() override { return *this; }
+
+  /// See AbstractAttribute::getState(...).
+  const AbstractState &getState() const override { return *this; }
+
+  /// See AbstractAttribute::getManifestPosition().
+  ManifestPosition getManifestPosition() const override { return MP_ARGUMENT; }
+
+  /// See AbstractAttribute::updateImpl(Attributor &A).
+  ChangeStatus updateImpl(Attributor &A) override;
+
+  /// Return the number of potential return values, -1 if unknown.
+  size_t getNumReturnValues() const {
+    return isValidState() ? ReturnedValues.size() : -1;
+  }
+
+  /// Return an assumed unique return value if a single candidate is found. If
+  /// there cannot be one, return a nullptr. If it is not clear yet, return the
+  /// Optional::NoneType.
+  Optional<Value *> getAssumedUniqueReturnValue() const;
+
+  /// See AbstractState::checkForallReturnedValues(...).
+  bool
+  checkForallReturnedValues(std::function<bool(Value &)> &Pred) const override;
+
+  /// Pretty print the attribute similar to the IR representation.
+  const std::string getAsStr() const override;
+
+  /// See AbstractState::isAtFixpoint().
+  bool isAtFixpoint() const override { return IsFixed; }
+
+  /// See AbstractState::isValidState().
+  bool isValidState() const override { return IsValidState; }
+
+  /// See AbstractState::indicateOptimisticFixpoint(...).
+  void indicateOptimisticFixpoint() override {
+    IsFixed = true;
+    IsValidState &= true;
+  }
+  void indicatePessimisticFixpoint() override {
+    IsFixed = true;
+    IsValidState = false;
+  }
+};
+
+ChangeStatus AAReturnedValuesImpl::manifest(Attributor &A) {
+  ChangeStatus Changed = ChangeStatus::UNCHANGED;
+
+  // Bookkeeping.
+  assert(isValidState());
+  NumFnKnownReturns++;
+
+  // Check if we have an assumed unique return value that we could manifest.
+  Optional<Value *> UniqueRV = getAssumedUniqueReturnValue();
+
+  if (!UniqueRV.hasValue() || !UniqueRV.getValue())
+    return Changed;
+
+  // Bookkeeping.
+  NumFnUniqueReturned++;
+
+  // If the assumed unique return value is an argument, annotate it.
+  if (auto *UniqueRVArg = dyn_cast<Argument>(UniqueRV.getValue())) {
+    AssociatedVal = UniqueRVArg;
+    Changed = AbstractAttribute::manifest(A) | Changed;
+  }
+
+  return Changed;
+}
+
+const std::string AAReturnedValuesImpl::getAsStr() const {
+  return (isAtFixpoint() ? "returns(#" : "may-return(#") +
+         (isValidState() ? std::to_string(getNumReturnValues()) : "?") + ")";
+}
+
+Optional<Value *> AAReturnedValuesImpl::getAssumedUniqueReturnValue() const {
+  // If checkForallReturnedValues provides a unique value, ignoring potential
+  // undef values that can also be present, it is assumed to be the actual
+  // return value and forwarded to the caller of this method. If there are
+  // multiple, a nullptr is returned indicating there cannot be a unique
+  // returned value.
+  Optional<Value *> UniqueRV;
+
+  std::function<bool(Value &)> Pred = [&](Value &RV) -> bool {
+    // If we found a second returned value and neither the current nor the saved
+    // one is an undef, there is no unique returned value. Undefs are special
+    // since we can pretend they have any value.
+    if (UniqueRV.hasValue() && UniqueRV != &RV &&
+        !(isa<UndefValue>(RV) || isa<UndefValue>(UniqueRV.getValue()))) {
+      UniqueRV = nullptr;
+      return false;
+    }
+
+    // Do not overwrite a value with an undef.
+    if (!UniqueRV.hasValue() || !isa<UndefValue>(RV))
+      UniqueRV = &RV;
+
+    return true;
+  };
+
+  if (!checkForallReturnedValues(Pred))
+    UniqueRV = nullptr;
+
+  return UniqueRV;
+}
+
+bool AAReturnedValuesImpl::checkForallReturnedValues(
+    std::function<bool(Value &)> &Pred) const {
+  if (!isValidState())
+    return false;
+
+  // Check all returned values but ignore call sites as long as we have not
+  // encountered an overdefined one during an update.
+  for (auto &It : ReturnedValues) {
+    Value *RV = It.first;
+
+    ImmutableCallSite ICS(RV);
+    if (ICS && !HasOverdefinedReturnedCalls)
+      continue;
+
+    if (!Pred(*RV))
+      return false;
+  }
+
+  return true;
+}
+
+ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
+
+  // Check if we know of any values returned by the associated function,
+  // if not, we are done.
+  if (getNumReturnValues() == 0) {
+    indicateOptimisticFixpoint();
+    return ChangeStatus::UNCHANGED;
+  }
+
+  // Check if any of the returned values is a call site we can refine.
+  decltype(ReturnedValues) AddRVs;
+  bool HasCallSite = false;
+
+  // Look at all returned call sites.
+  for (auto &It : ReturnedValues) {
+    SmallPtrSet<ReturnInst *, 2> &ReturnInsts = It.second;
+    Value *RV = It.first;
+    LLVM_DEBUG(dbgs() << "[AAReturnedValues] Potentially returned value " << *RV
+                      << "\n");
+
+    // Only call sites can change during an update, ignore the rest.
+    CallSite RetCS(RV);
+    if (!RetCS)
+      continue;
+
+    // For now, any call site we see will prevent us from directly fixing the
+    // state. However, if the information on the callees is fixed, the call
+    // sites will be removed and we will fix the information for this state.
+    HasCallSite = true;
+
+    // Try to find an assumed unique return value for the called function.
+    auto *RetCSAA = A.getAAFor<AAReturnedValuesImpl>(*this, *RV);
+    if (!RetCSAA) {
+      HasOverdefinedReturnedCalls = true;
+      LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned call site (" << *RV
+                        << ") with " << (RetCSAA ? "invalid" : "no")
+                        << " associated state\n");
+      continue;
+    }
+
+    // Try to find an assumed unique return value for the called function.
+    Optional<Value *> AssumedUniqueRV = RetCSAA->getAssumedUniqueReturnValue();
+
+    // If no assumed unique return value was found due to the lack of
+    // candidates, we may need to resolve more calls (through more update
+    // iterations) or the called function will not return. Either way, we simply
+    // stick with the call sites as return values. Because there were not
+    // multiple possibilities, we do not treat it as overdefined.
+    if (!AssumedUniqueRV.hasValue())
+      continue;
+
+    // If multiple, non-refinable values were found, there cannot be a unique
+    // return value for the called function. The returned call is overdefined!
+    if (!AssumedUniqueRV.getValue()) {
+      HasOverdefinedReturnedCalls = true;
+      LLVM_DEBUG(dbgs() << "[AAReturnedValues] Returned call site has multiple "
+                           "potentially returned values\n");
+      continue;
+    }
+
+    LLVM_DEBUG({
+      bool UniqueRVIsKnown = RetCSAA->isAtFixpoint();
+      dbgs() << "[AAReturnedValues] Returned call site "
+             << (UniqueRVIsKnown ? "known" : "assumed")
+             << " unique return value: " << *AssumedUniqueRV << "\n";
+    });
+
+    // The assumed unique return value.
+    Value *AssumedRetVal = AssumedUniqueRV.getValue();
+
+    // If the assumed unique return value is an argument, lookup the matching
+    // call site operand and recursively collect new returned values.
+    // If it is not an argument, it is just put into the set of returned values
+    // as we would have already looked through casts, phis, and similar values.
+    if (Argument *AssumedRetArg = dyn_cast<Argument>(AssumedRetVal))
+      collectValuesRecursively(A,
+                               RetCS.getArgOperand(AssumedRetArg->getArgNo()),
+                               ReturnInsts, AddRVs);
+    else
+      AddRVs[AssumedRetVal].insert(ReturnInsts.begin(), ReturnInsts.end());
+  }
+
+  // Keep track of any change to trigger updates on dependent attributes.
+  ChangeStatus Changed = ChangeStatus::UNCHANGED;
+
+  for (auto &It : AddRVs) {
+    assert(!It.second.empty() && "Entry does not add anything.");
+    auto &ReturnInsts = ReturnedValues[It.first];
+    for (ReturnInst *RI : It.second)
+      if (ReturnInsts.insert(RI).second) {
+        LLVM_DEBUG(dbgs() << "[AAReturnedValues] Add new returned value "
+                          << *It.first << " => " << *RI << "\n");
+        Changed = ChangeStatus::CHANGED;
+      }
+  }
+
+  // If there is no call site in the returned values we are done.
+  if (!HasCallSite) {
+    indicateOptimisticFixpoint();
+    return ChangeStatus::CHANGED;
+  }
+
+  return Changed;
+}
+
+/// ------------------------ NoSync Function Attribute -------------------------
+
+struct AANoSyncFunction : AANoSync, BooleanState {
+
+  AANoSyncFunction(Function &F, InformationCache &InfoCache)
+      : AANoSync(F, InfoCache) {}
+
+  /// See AbstractAttribute::getState()
+  /// {
+  AbstractState &getState() override { return *this; }
+  const AbstractState &getState() const override { return *this; }
+  /// }
+
+  /// See AbstractAttribute::getManifestPosition().
+  ManifestPosition getManifestPosition() const override { return MP_FUNCTION; }
+
+  const std::string getAsStr() const override {
+    return getAssumed() ? "nosync" : "may-sync";
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override;
+
+  /// See AANoSync::isAssumedNoSync()
+  bool isAssumedNoSync() const override { return getAssumed(); }
+
+  /// See AANoSync::isKnownNoSync()
+  bool isKnownNoSync() const override { return getKnown(); }
+
+  /// Helper function used to determine whether an instruction is non-relaxed
+  /// atomic. In other words, if an atomic instruction does not have unordered
+  /// or monotonic ordering.
+  static bool isNonRelaxedAtomic(Instruction *I);
+
+  /// Helper function used to determine whether an instruction is volatile.
+  static bool isVolatile(Instruction *I);
+
+  /// Helper function used to check if an intrinsic (memcpy, memmove, memset)
+  /// is nosync.
+  static bool isNoSyncIntrinsic(Instruction *I);
+};
+
+bool AANoSyncFunction::isNonRelaxedAtomic(Instruction *I) {
+  if (!I->isAtomic())
+    return false;
+
+  AtomicOrdering Ordering;
+  switch (I->getOpcode()) {
+  case Instruction::AtomicRMW:
+    Ordering = cast<AtomicRMWInst>(I)->getOrdering();
+    break;
+  case Instruction::Store:
+    Ordering = cast<StoreInst>(I)->getOrdering();
+    break;
+  case Instruction::Load:
+    Ordering = cast<LoadInst>(I)->getOrdering();
+    break;
+  case Instruction::Fence: {
+    auto *FI = cast<FenceInst>(I);
+    if (FI->getSyncScopeID() == SyncScope::SingleThread)
+      return false;
+    Ordering = FI->getOrdering();
+    break;
+  }
+  case Instruction::AtomicCmpXchg: {
+    AtomicOrdering Success = cast<AtomicCmpXchgInst>(I)->getSuccessOrdering();
+    AtomicOrdering Failure = cast<AtomicCmpXchgInst>(I)->getFailureOrdering();
+    // Only if both are relaxed, then it can be treated as relaxed.
+    // Otherwise it is non-relaxed.
+    if (Success != AtomicOrdering::Unordered &&
+        Success != AtomicOrdering::Monotonic)
+      return true;
+    if (Failure != AtomicOrdering::Unordered &&
+        Failure != AtomicOrdering::Monotonic)
+      return true;
+    return false;
+  }
+  default:
+    llvm_unreachable(
+        "New atomic operations need to be known in the attributor.");
+  }
+
+  // Relaxed.
+  if (Ordering == AtomicOrdering::Unordered ||
+      Ordering == AtomicOrdering::Monotonic)
+    return false;
+  return true;
+}
+
+/// Checks if an intrinsic is nosync. Currently only checks mem* intrinsics.
+/// FIXME: We should improve the handling of intrinsics.
+bool AANoSyncFunction::isNoSyncIntrinsic(Instruction *I) {
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    switch (II->getIntrinsicID()) {
+    /// Element wise atomic memory intrinsics can only be unordered,
+    /// therefore nosync.
+    case Intrinsic::memset_element_unordered_atomic:
+    case Intrinsic::memmove_element_unordered_atomic:
+    case Intrinsic::memcpy_element_unordered_atomic:
+      return true;
+    case Intrinsic::memset:
+    case Intrinsic::memmove:
+    case Intrinsic::memcpy:
+      if (!cast<MemIntrinsic>(II)->isVolatile())
+        return true;
+      return false;
+    default:
+      return false;
+    }
+  }
+  return false;
+}
+
+bool AANoSyncFunction::isVolatile(Instruction *I) {
+  assert(!ImmutableCallSite(I) && !isa<CallBase>(I) &&
+         "Calls should not be checked here");
+
+  switch (I->getOpcode()) {
+  case Instruction::AtomicRMW:
+    return cast<AtomicRMWInst>(I)->isVolatile();
+  case Instruction::Store:
+    return cast<StoreInst>(I)->isVolatile();
+  case Instruction::Load:
+    return cast<LoadInst>(I)->isVolatile();
+  case Instruction::AtomicCmpXchg:
+    return cast<AtomicCmpXchgInst>(I)->isVolatile();
+  default:
+    return false;
+  }
+}
+
+ChangeStatus AANoSyncFunction::updateImpl(Attributor &A) {
+  Function &F = getAnchorScope();
+
+  /// We are looking for volatile instructions or Non-Relaxed atomics.
+  /// FIXME: We should improve the handling of intrinsics.
+  for (Instruction *I : InfoCache.getReadOrWriteInstsForFunction(F)) {
+    ImmutableCallSite ICS(I);
+    auto *NoSyncAA = A.getAAFor<AANoSyncFunction>(*this, *I);
+
+    if (isa<IntrinsicInst>(I) && isNoSyncIntrinsic(I))
+      continue;
+
+    if (ICS && (!NoSyncAA || !NoSyncAA->isAssumedNoSync()) &&
+        !ICS.hasFnAttr(Attribute::NoSync)) {
+      indicatePessimisticFixpoint();
+      return ChangeStatus::CHANGED;
+    }
+
+    if (ICS)
+      continue;
+
+    if (!isVolatile(I) && !isNonRelaxedAtomic(I))
+      continue;
+
+    indicatePessimisticFixpoint();
+    return ChangeStatus::CHANGED;
+  }
+
+  auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
+  auto Opcodes = {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr,
+                  (unsigned)Instruction::Call};
+
+  for (unsigned Opcode : Opcodes) {
+    for (Instruction *I : OpcodeInstMap[Opcode]) {
+      // At this point we handled all read/write effects and they are all
+      // nosync, so they can be skipped.
+      if (I->mayReadOrWriteMemory())
+        continue;
+
+      ImmutableCallSite ICS(I);
+
+      // non-convergent and readnone imply nosync.
+      if (!ICS.isConvergent())
+        continue;
+
+      indicatePessimisticFixpoint();
+      return ChangeStatus::CHANGED;
+    }
+  }
+
+  return ChangeStatus::UNCHANGED;
+}
+
+/// ------------------------ No-Free Attributes ----------------------------
+
+struct AANoFreeFunction : AbstractAttribute, BooleanState {
+
+  /// See AbstractAttribute::AbstractAttribute(...).
+  AANoFreeFunction(Function &F, InformationCache &InfoCache)
+      : AbstractAttribute(F, InfoCache) {}
+
+  /// See AbstractAttribute::getState()
+  ///{
+  AbstractState &getState() override { return *this; }
+  const AbstractState &getState() const override { return *this; }
+  ///}
+
+  /// See AbstractAttribute::getManifestPosition().
+  ManifestPosition getManifestPosition() const override { return MP_FUNCTION; }
+
+  /// See AbstractAttribute::getAsStr().
+  const std::string getAsStr() const override {
+    return getAssumed() ? "nofree" : "may-free";
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override;
+
+  /// See AbstractAttribute::getAttrKind().
+  Attribute::AttrKind getAttrKind() const override { return ID; }
+
+  /// Return true if "nofree" is assumed.
+  bool isAssumedNoFree() const { return getAssumed(); }
+
+  /// Return true if "nofree" is known.
+  bool isKnownNoFree() const { return getKnown(); }
+
+  /// The identifier used by the Attributor for this class of attributes.
+  static constexpr Attribute::AttrKind ID = Attribute::NoFree;
+};
+
+ChangeStatus AANoFreeFunction::updateImpl(Attributor &A) {
+  Function &F = getAnchorScope();
+
+  // The map from instruction opcodes to those instructions in the function.
+  auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F);
+
+  for (unsigned Opcode :
+       {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr,
+        (unsigned)Instruction::Call}) {
+    for (Instruction *I : OpcodeInstMap[Opcode]) {
+
+      auto ICS = ImmutableCallSite(I);
+      auto *NoFreeAA = A.getAAFor<AANoFreeFunction>(*this, *I);
+
+      if ((!NoFreeAA || !NoFreeAA->isAssumedNoFree()) &&
+          !ICS.hasFnAttr(Attribute::NoFree)) {
+        indicatePessimisticFixpoint();
+        return ChangeStatus::CHANGED;
+      }
+    }
+  }
+  return ChangeStatus::UNCHANGED;
+}
+
+/// ------------------------ NonNull Argument Attribute ------------------------
+struct AANonNullImpl : AANonNull, BooleanState {
+
+  AANonNullImpl(Value &V, InformationCache &InfoCache)
+      : AANonNull(V, InfoCache) {}
+
+  AANonNullImpl(Value *AssociatedVal, Value &AnchoredValue,
+                InformationCache &InfoCache)
+      : AANonNull(AssociatedVal, AnchoredValue, InfoCache) {}
+
+  /// See AbstractAttribute::getState()
+  /// {
+  AbstractState &getState() override { return *this; }
+  const AbstractState &getState() const override { return *this; }
+  /// }
+
+  /// See AbstractAttribute::getAsStr().
+  const std::string getAsStr() const override {
+    return getAssumed() ? "nonnull" : "may-null";
+  }
+
+  /// See AANonNull::isAssumedNonNull().
+  bool isAssumedNonNull() const override { return getAssumed(); }
+
+  /// See AANonNull::isKnownNonNull().
+  bool isKnownNonNull() const override { return getKnown(); }
+
+  /// Generate a predicate that checks if a given value is assumed nonnull.
+  /// The generated function returns true if a value satisfies any of the
+  /// following conditions.
+  /// (i) A value is known nonZero(=nonnull).
+  /// (ii) A value is associated with AANonNull and its isAssumedNonNull() is
+  /// true.
+  std::function<bool(Value &)> generatePredicate(Attributor &);
+};
+
+std::function<bool(Value &)> AANonNullImpl::generatePredicate(Attributor &A) {
+  // FIXME: The `AAReturnedValues` should provide the predicate with the
+  // `ReturnInst` vector as well such that we can use the control flow sensitive
+  // version of `isKnownNonZero`. This should fix `test11` in
+  // `test/Transforms/FunctionAttrs/nonnull.ll`
+
+  std::function<bool(Value &)> Pred = [&](Value &RV) -> bool {
+    if (isKnownNonZero(&RV, getAnchorScope().getParent()->getDataLayout()))
+      return true;
+
+    auto *NonNullAA = A.getAAFor<AANonNull>(*this, RV);
+
+    ImmutableCallSite ICS(&RV);
+
+    if ((!NonNullAA || !NonNullAA->isAssumedNonNull()) &&
+        (!ICS || !ICS.hasRetAttr(Attribute::NonNull)))
+      return false;
+
+    return true;
+  };
+
+  return Pred;
+}
+
+/// NonNull attribute for function return value.
+struct AANonNullReturned : AANonNullImpl {
+
+  AANonNullReturned(Function &F, InformationCache &InfoCache)
+      : AANonNullImpl(F, InfoCache) {}
+
+  /// See AbstractAttribute::getManifestPosition().
+  ManifestPosition getManifestPosition() const override { return MP_RETURNED; }
+
+  /// See AbstractAttribute::initialize(...).
+ void initialize(Attributor &A) override { + Function &F = getAnchorScope(); + + // Already nonnull. + if (F.getAttributes().hasAttribute(AttributeList::ReturnIndex, + Attribute::NonNull)) + indicateOptimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override; +}; + +ChangeStatus AANonNullReturned::updateImpl(Attributor &A) { + Function &F = getAnchorScope(); + + auto *AARetVal = A.getAAFor<AAReturnedValues>(*this, F); + if (!AARetVal) { + indicatePessimisticFixpoint(); + return ChangeStatus::CHANGED; + } + + std::function<bool(Value &)> Pred = this->generatePredicate(A); + if (!AARetVal->checkForallReturnedValues(Pred)) { + indicatePessimisticFixpoint(); + return ChangeStatus::CHANGED; + } + return ChangeStatus::UNCHANGED; +} + +/// NonNull attribute for function argument. +struct AANonNullArgument : AANonNullImpl { + + AANonNullArgument(Argument &A, InformationCache &InfoCache) + : AANonNullImpl(A, InfoCache) {} + + /// See AbstractAttribute::getManifestPosition(). + ManifestPosition getManifestPosition() const override { return MP_ARGUMENT; } + + /// See AbstractAttriubute::initialize(...). + void initialize(Attributor &A) override { + Argument *Arg = cast<Argument>(getAssociatedValue()); + if (Arg->hasNonNullAttr()) + indicateOptimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override; +}; + +/// NonNull attribute for a call site argument. +struct AANonNullCallSiteArgument : AANonNullImpl { + + /// See AANonNullImpl::AANonNullImpl(...). + AANonNullCallSiteArgument(CallSite CS, unsigned ArgNo, + InformationCache &InfoCache) + : AANonNullImpl(CS.getArgOperand(ArgNo), *CS.getInstruction(), InfoCache), + ArgNo(ArgNo) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + CallSite CS(&getAnchoredValue()); + if (isKnownNonZero(getAssociatedValue(), + getAnchorScope().getParent()->getDataLayout()) || + CS.paramHasAttr(ArgNo, getAttrKind())) + indicateOptimisticFixpoint(); + } + + /// See AbstractAttribute::updateImpl(Attributor &A). + ChangeStatus updateImpl(Attributor &A) override; + + /// See AbstractAttribute::getManifestPosition(). + ManifestPosition getManifestPosition() const override { + return MP_CALL_SITE_ARGUMENT; + }; + + // Return argument index of associated value. + int getArgNo() const { return ArgNo; } + +private: + unsigned ArgNo; +}; +ChangeStatus AANonNullArgument::updateImpl(Attributor &A) { + Function &F = getAnchorScope(); + Argument &Arg = cast<Argument>(getAnchoredValue()); + + unsigned ArgNo = Arg.getArgNo(); + + // Callback function + std::function<bool(CallSite)> CallSiteCheck = [&](CallSite CS) { + assert(CS && "Sanity check: Call site was not initialized properly!"); + + auto *NonNullAA = A.getAAFor<AANonNull>(*this, *CS.getInstruction(), ArgNo); + + // Check that NonNullAA is AANonNullCallSiteArgument. 
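A hypothetical source-level illustration of what the argument and call-site pieces cooperate on (the deduction itself runs on LLVM IR; 'passthrough', 'caller', and 'G' are made-up names; 'passthrough' is static so every call site is visible, matching the internal-linkage requirement enforced by Attributor::checkForAllCallSites):

// Every call site passes a non-null pointer, so the argument deduction can
// assume 'P' is nonnull; because 'P' is also the returned value, the nonnull
// returned-value deduction can follow.
static int *passthrough(int *P) { return P; }
static int G;
int *caller() { return passthrough(&G); }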
+ if (NonNullAA) { + ImmutableCallSite ICS(&NonNullAA->getAnchoredValue()); + if (ICS && CS.getInstruction() == ICS.getInstruction()) + return NonNullAA->isAssumedNonNull(); + return false; + } + + if (CS.paramHasAttr(ArgNo, Attribute::NonNull)) + return true; + + Value *V = CS.getArgOperand(ArgNo); + if (isKnownNonZero(V, getAnchorScope().getParent()->getDataLayout())) + return true; + + return false; + }; + if (!A.checkForAllCallSites(F, CallSiteCheck, true)) { + indicatePessimisticFixpoint(); + return ChangeStatus::CHANGED; + } + return ChangeStatus::UNCHANGED; +} + +ChangeStatus AANonNullCallSiteArgument::updateImpl(Attributor &A) { + // NOTE: Never look at the argument of the callee in this method. + // If we do this, "nonnull" is always deduced because of the assumption. + + Value &V = *getAssociatedValue(); + + auto *NonNullAA = A.getAAFor<AANonNull>(*this, V); + + if (!NonNullAA || !NonNullAA->isAssumedNonNull()) { + indicatePessimisticFixpoint(); + return ChangeStatus::CHANGED; + } + + return ChangeStatus::UNCHANGED; +} + +/// ------------------------ Will-Return Attributes ---------------------------- + +struct AAWillReturnImpl : public AAWillReturn, BooleanState { + + /// See AbstractAttribute::AbstractAttribute(...). + AAWillReturnImpl(Function &F, InformationCache &InfoCache) + : AAWillReturn(F, InfoCache) {} + + /// See AAWillReturn::isKnownWillReturn(). + bool isKnownWillReturn() const override { return getKnown(); } + + /// See AAWillReturn::isAssumedWillReturn(). + bool isAssumedWillReturn() const override { return getAssumed(); } + + /// See AbstractAttribute::getState(...). + AbstractState &getState() override { return *this; } + + /// See AbstractAttribute::getState(...). + const AbstractState &getState() const override { return *this; } + + /// See AbstractAttribute::getAsStr() + const std::string getAsStr() const override { + return getAssumed() ? "willreturn" : "may-noreturn"; + } +}; + +struct AAWillReturnFunction final : AAWillReturnImpl { + + /// See AbstractAttribute::AbstractAttribute(...). + AAWillReturnFunction(Function &F, InformationCache &InfoCache) + : AAWillReturnImpl(F, InfoCache) {} + + /// See AbstractAttribute::getManifestPosition(). + ManifestPosition getManifestPosition() const override { + return MP_FUNCTION; + } + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override; + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override; +}; + +// Helper function that checks whether a function has any cycle. +// TODO: Replace with more efficent code +bool containsCycle(Function &F) { + SmallPtrSet<BasicBlock *, 32> Visited; + + // Traverse BB by dfs and check whether successor is already visited. + for (BasicBlock *BB : depth_first(&F)) { + Visited.insert(BB); + for (auto *SuccBB : successors(BB)) { + if (Visited.count(SuccBB)) + return true; + } + } + return false; +} + +// Helper function that checks the function have a loop which might become an +// endless loop +// FIXME: Any cycle is regarded as endless loop for now. +// We have to allow some patterns. +bool containsPossiblyEndlessLoop(Function &F) { return containsCycle(F); } + +void AAWillReturnFunction::initialize(Attributor &A) { + Function &F = getAnchorScope(); + + if (containsPossiblyEndlessLoop(F)) + indicatePessimisticFixpoint(); +} + +ChangeStatus AAWillReturnFunction::updateImpl(Attributor &A) { + Function &F = getAnchorScope(); + + // The map from instruction opcodes to those instructions in the function. 
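The DFS check in containsCycle above conservatively flags any edge to an already-visited block, including joins in acyclic control flow. A possibly tighter alternative, sketched here under the usual GraphTraits for Function (illustrative only, not part of the patch), asks whether any strongly connected component of the CFG actually contains a cycle:

#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Function.h"

static bool containsCycleViaSCC(llvm::Function &F) {
  for (auto It = llvm::scc_begin(&F); !It.isAtEnd(); ++It) {
    const std::vector<llvm::BasicBlock *> &SCC = *It;
    if (SCC.size() > 1)
      return true;                 // Multi-block SCC: definitely a cycle.
    if (llvm::is_contained(llvm::successors(SCC.front()), SCC.front()))
      return true;                 // Single block that branches to itself.
  }
  return false;
}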
+ auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(F); + + for (unsigned Opcode : + {(unsigned)Instruction::Invoke, (unsigned)Instruction::CallBr, + (unsigned)Instruction::Call}) { + for (Instruction *I : OpcodeInstMap[Opcode]) { + auto ICS = ImmutableCallSite(I); + + if (ICS.hasFnAttr(Attribute::WillReturn)) + continue; + + auto *WillReturnAA = A.getAAFor<AAWillReturn>(*this, *I); + if (!WillReturnAA || !WillReturnAA->isAssumedWillReturn()) { + indicatePessimisticFixpoint(); + return ChangeStatus::CHANGED; + } + + auto *NoRecurseAA = A.getAAFor<AANoRecurse>(*this, *I); + + // FIXME: (i) Prohibit any recursion for now. + // (ii) AANoRecurse isn't implemented yet so currently any call is + // regarded as having recursion. + // Code below should be + // if ((!NoRecurseAA || !NoRecurseAA->isAssumedNoRecurse()) && + if (!NoRecurseAA && !ICS.hasFnAttr(Attribute::NoRecurse)) { + indicatePessimisticFixpoint(); + return ChangeStatus::CHANGED; + } + } + } + + return ChangeStatus::UNCHANGED; +} + +/// ---------------------------------------------------------------------------- +/// Attributor +/// ---------------------------------------------------------------------------- + +bool Attributor::checkForAllCallSites(Function &F, + std::function<bool(CallSite)> &Pred, + bool RequireAllCallSites) { + // We can try to determine information from + // the call sites. However, this is only possible all call sites are known, + // hence the function has internal linkage. + if (RequireAllCallSites && !F.hasInternalLinkage()) { + LLVM_DEBUG( + dbgs() + << "Attributor: Function " << F.getName() + << " has no internal linkage, hence not all call sites are known\n"); + return false; + } + + for (const Use &U : F.uses()) { + + CallSite CS(U.getUser()); + if (!CS || !CS.isCallee(&U) || !CS.getCaller()->hasExactDefinition()) { + if (!RequireAllCallSites) + continue; + + LLVM_DEBUG(dbgs() << "Attributor: User " << *U.getUser() + << " is an invalid use of " << F.getName() << "\n"); + return false; + } + + if (Pred(CS)) + continue; + + LLVM_DEBUG(dbgs() << "Attributor: Call site callback failed for " + << *CS.getInstruction() << "\n"); + return false; + } + + return true; +} + +ChangeStatus Attributor::run() { + // Initialize all abstract attributes. + for (AbstractAttribute *AA : AllAbstractAttributes) + AA->initialize(*this); + + LLVM_DEBUG(dbgs() << "[Attributor] Identified and initialized " + << AllAbstractAttributes.size() + << " abstract attributes.\n"); + + // Now that all abstract attributes are collected and initialized we start + // the abstract analysis. + + unsigned IterationCounter = 1; + + SmallVector<AbstractAttribute *, 64> ChangedAAs; + SetVector<AbstractAttribute *> Worklist; + Worklist.insert(AllAbstractAttributes.begin(), AllAbstractAttributes.end()); + + do { + LLVM_DEBUG(dbgs() << "\n\n[Attributor] #Iteration: " << IterationCounter + << ", Worklist size: " << Worklist.size() << "\n"); + + // Add all abstract attributes that are potentially dependent on one that + // changed to the work list. + for (AbstractAttribute *ChangedAA : ChangedAAs) { + auto &QuerriedAAs = QueryMap[ChangedAA]; + Worklist.insert(QuerriedAAs.begin(), QuerriedAAs.end()); + } + + // Reset the changed set. + ChangedAAs.clear(); + + // Update all abstract attribute in the work list and record the ones that + // changed. + for (AbstractAttribute *AA : Worklist) + if (AA->update(*this) == ChangeStatus::CHANGED) + ChangedAAs.push_back(AA); + + // Reset the work list and repopulate with the changed abstract attributes. 
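For reference, a minimal generic sketch of the dependency-driven fixpoint scheme implemented above (illustrative only; 'runFixpoint', 'Dependents', and the requirement that T expose update() returning true iff its state changed are assumptions of the sketch, not names from the patch): every element on the worklist is updated, and anything that changed is re-queued for the next round together with the elements registered as depending on it.

#include <map>
#include <set>
#include <vector>

template <typename T>
void runFixpoint(std::set<T *> Worklist,
                 std::map<T *, std::vector<T *>> &Dependents,
                 unsigned MaxIterations) {
  for (unsigned It = 0; It < MaxIterations && !Worklist.empty(); ++It) {
    std::set<T *> Next;
    for (T *Elem : Worklist)
      if (Elem->update()) {
        Next.insert(Elem);              // Revisit the changed element...
        for (T *Dep : Dependents[Elem]) // ...and everything that queried it.
          Next.insert(Dep);
      }
    Worklist = std::move(Next);
  }
}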
+ // Note that dependent ones are added above. + Worklist.clear(); + Worklist.insert(ChangedAAs.begin(), ChangedAAs.end()); + + } while (!Worklist.empty() && ++IterationCounter < MaxFixpointIterations); + + LLVM_DEBUG(dbgs() << "\n[Attributor] Fixpoint iteration done after: " + << IterationCounter << "/" << MaxFixpointIterations + << " iterations\n"); + + bool FinishedAtFixpoint = Worklist.empty(); + + // Reset abstract arguments not settled in a sound fixpoint by now. This + // happens when we stopped the fixpoint iteration early. Note that only the + // ones marked as "changed" *and* the ones transitively depending on them + // need to be reverted to a pessimistic state. Others might not be in a + // fixpoint state but we can use the optimistic results for them anyway. + SmallPtrSet<AbstractAttribute *, 32> Visited; + for (unsigned u = 0; u < ChangedAAs.size(); u++) { + AbstractAttribute *ChangedAA = ChangedAAs[u]; + if (!Visited.insert(ChangedAA).second) + continue; + + AbstractState &State = ChangedAA->getState(); + if (!State.isAtFixpoint()) { + State.indicatePessimisticFixpoint(); + + NumAttributesTimedOut++; + } + + auto &QuerriedAAs = QueryMap[ChangedAA]; + ChangedAAs.append(QuerriedAAs.begin(), QuerriedAAs.end()); + } + + LLVM_DEBUG({ + if (!Visited.empty()) + dbgs() << "\n[Attributor] Finalized " << Visited.size() + << " abstract attributes.\n"; + }); + + unsigned NumManifested = 0; + unsigned NumAtFixpoint = 0; + ChangeStatus ManifestChange = ChangeStatus::UNCHANGED; + for (AbstractAttribute *AA : AllAbstractAttributes) { + AbstractState &State = AA->getState(); + + // If there is not already a fixpoint reached, we can now take the + // optimistic state. This is correct because we enforced a pessimistic one + // on abstract attributes that were transitively dependent on a changed one + // already above. + if (!State.isAtFixpoint()) + State.indicateOptimisticFixpoint(); + + // If the state is invalid, we do not try to manifest it. + if (!State.isValidState()) + continue; + + // Manifest the state and record if we changed the IR. + ChangeStatus LocalChange = AA->manifest(*this); + ManifestChange = ManifestChange | LocalChange; + + NumAtFixpoint++; + NumManifested += (LocalChange == ChangeStatus::CHANGED); + } + + (void)NumManifested; + (void)NumAtFixpoint; + LLVM_DEBUG(dbgs() << "\n[Attributor] Manifested " << NumManifested + << " arguments while " << NumAtFixpoint + << " were in a valid fixpoint state\n"); + + // If verification is requested, we finished this run at a fixpoint, and the + // IR was changed, we re-run the whole fixpoint analysis, starting at + // re-initialization of the arguments. This re-run should not result in an IR + // change. Though, the (virtual) state of attributes at the end of the re-run + // might be more optimistic than the known state or the IR state if the better + // state cannot be manifested. + if (VerifyAttributor && FinishedAtFixpoint && + ManifestChange == ChangeStatus::CHANGED) { + VerifyAttributor = false; + ChangeStatus VerifyStatus = run(); + if (VerifyStatus != ChangeStatus::UNCHANGED) + llvm_unreachable( + "Attributor verification failed, re-run did result in an IR change " + "even after a fixpoint was reached in the original run. 
(False " + "positives possible!)"); + VerifyAttributor = true; + } + + NumAttributesManifested += NumManifested; + NumAttributesValidFixpoint += NumAtFixpoint; + + return ManifestChange; +} + +void Attributor::identifyDefaultAbstractAttributes( + Function &F, InformationCache &InfoCache, + DenseSet</* Attribute::AttrKind */ unsigned> *Whitelist) { + + // Every function can be nounwind. + registerAA(*new AANoUnwindFunction(F, InfoCache)); + + // Every function might be marked "nosync" + registerAA(*new AANoSyncFunction(F, InfoCache)); + + // Every function might be "no-free". + registerAA(*new AANoFreeFunction(F, InfoCache)); + + // Return attributes are only appropriate if the return type is non void. + Type *ReturnType = F.getReturnType(); + if (!ReturnType->isVoidTy()) { + // Argument attribute "returned" --- Create only one per function even + // though it is an argument attribute. + if (!Whitelist || Whitelist->count(AAReturnedValues::ID)) + registerAA(*new AAReturnedValuesImpl(F, InfoCache)); + + // Every function with pointer return type might be marked nonnull. + if (ReturnType->isPointerTy() && + (!Whitelist || Whitelist->count(AANonNullReturned::ID))) + registerAA(*new AANonNullReturned(F, InfoCache)); + } + + // Every argument with pointer type might be marked nonnull. + for (Argument &Arg : F.args()) { + if (Arg.getType()->isPointerTy()) + registerAA(*new AANonNullArgument(Arg, InfoCache)); + } + + // Every function might be "will-return". + registerAA(*new AAWillReturnFunction(F, InfoCache)); + + // Walk all instructions to find more attribute opportunities and also + // interesting instructions that might be queried by abstract attributes + // during their initialization or update. + auto &ReadOrWriteInsts = InfoCache.FuncRWInstsMap[&F]; + auto &InstOpcodeMap = InfoCache.FuncInstOpcodeMap[&F]; + + for (Instruction &I : instructions(&F)) { + bool IsInterestingOpcode = false; + + // To allow easy access to all instructions in a function with a given + // opcode we store them in the InfoCache. As not all opcodes are interesting + // to concrete attributes we only cache the ones that are as identified in + // the following switch. + // Note: There are no concrete attributes now so this is initially empty. + switch (I.getOpcode()) { + default: + assert((!ImmutableCallSite(&I)) && (!isa<CallBase>(&I)) && + "New call site/base instruction type needs to be known int the " + "attributor."); + break; + case Instruction::Call: + case Instruction::CallBr: + case Instruction::Invoke: + case Instruction::CleanupRet: + case Instruction::CatchSwitch: + case Instruction::Resume: + case Instruction::Ret: + IsInterestingOpcode = true; + } + if (IsInterestingOpcode) + InstOpcodeMap[I.getOpcode()].push_back(&I); + if (I.mayReadOrWriteMemory()) + ReadOrWriteInsts.push_back(&I); + + CallSite CS(&I); + if (CS && CS.getCalledFunction()) { + for (int i = 0, e = CS.getCalledFunction()->arg_size(); i < e; i++) { + if (!CS.getArgument(i)->getType()->isPointerTy()) + continue; + + // Call site argument attribute "non-null". + registerAA(*new AANonNullCallSiteArgument(CS, i, InfoCache), i); + } + } + } +} + +/// Helpers to ease debugging through output streams and print calls. +/// +///{ +raw_ostream &llvm::operator<<(raw_ostream &OS, ChangeStatus S) { + return OS << (S == ChangeStatus::CHANGED ? 
"changed" : "unchanged"); +} + +raw_ostream &llvm::operator<<(raw_ostream &OS, + AbstractAttribute::ManifestPosition AP) { + switch (AP) { + case AbstractAttribute::MP_ARGUMENT: + return OS << "arg"; + case AbstractAttribute::MP_CALL_SITE_ARGUMENT: + return OS << "cs_arg"; + case AbstractAttribute::MP_FUNCTION: + return OS << "fn"; + case AbstractAttribute::MP_RETURNED: + return OS << "fn_ret"; + } + llvm_unreachable("Unknown attribute position!"); +} + +raw_ostream &llvm::operator<<(raw_ostream &OS, const AbstractState &S) { + return OS << (!S.isValidState() ? "top" : (S.isAtFixpoint() ? "fix" : "")); +} + +raw_ostream &llvm::operator<<(raw_ostream &OS, const AbstractAttribute &AA) { + AA.print(OS); + return OS; +} + +void AbstractAttribute::print(raw_ostream &OS) const { + OS << "[" << getManifestPosition() << "][" << getAsStr() << "][" + << AnchoredVal.getName() << "]"; +} +///} + +/// ---------------------------------------------------------------------------- +/// Pass (Manager) Boilerplate +/// ---------------------------------------------------------------------------- + +static bool runAttributorOnModule(Module &M) { + if (DisableAttributor) + return false; + + LLVM_DEBUG(dbgs() << "[Attributor] Run on module with " << M.size() + << " functions.\n"); + + // Create an Attributor and initially empty information cache that is filled + // while we identify default attribute opportunities. + Attributor A; + InformationCache InfoCache; + + for (Function &F : M) { + // TODO: Not all attributes require an exact definition. Find a way to + // enable deduction for some but not all attributes in case the + // definition might be changed at runtime, see also + // http://lists.llvm.org/pipermail/llvm-dev/2018-February/121275.html. + // TODO: We could always determine abstract attributes and if sufficient + // information was found we could duplicate the functions that do not + // have an exact definition. + if (!F.hasExactDefinition()) { + NumFnWithoutExactDefinition++; + continue; + } + + // For now we ignore naked and optnone functions. + if (F.hasFnAttribute(Attribute::Naked) || + F.hasFnAttribute(Attribute::OptimizeNone)) + continue; + + NumFnWithExactDefinition++; + + // Populate the Attributor with abstract attribute opportunities in the + // function and the information cache with IR information. + A.identifyDefaultAbstractAttributes(F, InfoCache); + } + + return A.run() == ChangeStatus::CHANGED; +} + +PreservedAnalyses AttributorPass::run(Module &M, ModuleAnalysisManager &AM) { + if (runAttributorOnModule(M)) { + // FIXME: Think about passes we will preserve and add them here. + return PreservedAnalyses::none(); + } + return PreservedAnalyses::all(); +} + +namespace { + +struct AttributorLegacyPass : public ModulePass { + static char ID; + + AttributorLegacyPass() : ModulePass(ID) { + initializeAttributorLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + return runAttributorOnModule(M); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + // FIXME: Think about passes we will preserve and add them here. 
+ AU.setPreservesCFG(); + } +}; + +} // end anonymous namespace + +Pass *llvm::createAttributorLegacyPass() { return new AttributorLegacyPass(); } + +char AttributorLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(AttributorLegacyPass, "attributor", + "Deduce and propagate attributes", false, false) +INITIALIZE_PASS_END(AttributorLegacyPass, "attributor", + "Deduce and propagate attributes", false, false) diff --git a/lib/Transforms/IPO/BarrierNoopPass.cpp b/lib/Transforms/IPO/BarrierNoopPass.cpp index 05fc3dd6950c..6b68aa90c567 100644 --- a/lib/Transforms/IPO/BarrierNoopPass.cpp +++ b/lib/Transforms/IPO/BarrierNoopPass.cpp @@ -1,9 +1,8 @@ //===- BarrierNoopPass.cpp - A barrier pass for the pass manager ----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/IPO/BlockExtractor.cpp b/lib/Transforms/IPO/BlockExtractor.cpp index ff5ee817da49..6c365f3f3cbe 100644 --- a/lib/Transforms/IPO/BlockExtractor.cpp +++ b/lib/Transforms/IPO/BlockExtractor.cpp @@ -1,9 +1,8 @@ //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -23,6 +22,7 @@ #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CodeExtractor.h" + using namespace llvm; #define DEBUG_TYPE "block-extractor" @@ -36,22 +36,48 @@ static cl::opt<std::string> BlockExtractorFile( cl::opt<bool> BlockExtractorEraseFuncs("extract-blocks-erase-funcs", cl::desc("Erase the existing functions"), cl::Hidden); - namespace { class BlockExtractor : public ModulePass { - SmallVector<BasicBlock *, 16> Blocks; + SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupsOfBlocks; bool EraseFunctions; - SmallVector<std::pair<std::string, std::string>, 32> BlocksByName; + /// Map a function name to groups of blocks. + SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4> + BlocksByName; + + void init(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> + &GroupsOfBlocksToExtract) { + for (const SmallVectorImpl<BasicBlock *> &GroupOfBlocks : + GroupsOfBlocksToExtract) { + SmallVector<BasicBlock *, 16> NewGroup; + NewGroup.append(GroupOfBlocks.begin(), GroupOfBlocks.end()); + GroupsOfBlocks.emplace_back(NewGroup); + } + if (!BlockExtractorFile.empty()) + loadFile(); + } public: static char ID; BlockExtractor(const SmallVectorImpl<BasicBlock *> &BlocksToExtract, bool EraseFunctions) - : ModulePass(ID), Blocks(BlocksToExtract.begin(), BlocksToExtract.end()), - EraseFunctions(EraseFunctions) { - if (!BlockExtractorFile.empty()) - loadFile(); + : ModulePass(ID), EraseFunctions(EraseFunctions) { + // We want one group per element of the input list. 
+ SmallVector<SmallVector<BasicBlock *, 16>, 4> MassagedGroupsOfBlocks; + for (BasicBlock *BB : BlocksToExtract) { + SmallVector<BasicBlock *, 16> NewGroup; + NewGroup.push_back(BB); + MassagedGroupsOfBlocks.push_back(NewGroup); + } + init(MassagedGroupsOfBlocks); } + + BlockExtractor(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> + &GroupsOfBlocksToExtract, + bool EraseFunctions) + : ModulePass(ID), EraseFunctions(EraseFunctions) { + init(GroupsOfBlocksToExtract); + } + BlockExtractor() : BlockExtractor(SmallVector<BasicBlock *, 0>(), false) {} bool runOnModule(Module &M) override; @@ -70,6 +96,12 @@ ModulePass *llvm::createBlockExtractorPass( const SmallVectorImpl<BasicBlock *> &BlocksToExtract, bool EraseFunctions) { return new BlockExtractor(BlocksToExtract, EraseFunctions); } +ModulePass *llvm::createBlockExtractorPass( + const SmallVectorImpl<SmallVector<BasicBlock *, 16>> + &GroupsOfBlocksToExtract, + bool EraseFunctions) { + return new BlockExtractor(GroupsOfBlocksToExtract, EraseFunctions); +} /// Gets all of the blocks specified in the input file. void BlockExtractor::loadFile() { @@ -82,8 +114,17 @@ void BlockExtractor::loadFile() { Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1, /*KeepEmpty=*/false); for (const auto &Line : Lines) { - auto FBPair = Line.split(' '); - BlocksByName.push_back({FBPair.first, FBPair.second}); + SmallVector<StringRef, 4> LineSplit; + Line.split(LineSplit, ' ', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + if (LineSplit.empty()) + continue; + SmallVector<StringRef, 4> BBNames; + LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + if (BBNames.empty()) + report_fatal_error("Missing bbs name"); + BlocksByName.push_back({LineSplit[0], {BBNames.begin(), BBNames.end()}}); } } @@ -130,33 +171,46 @@ bool BlockExtractor::runOnModule(Module &M) { } // Get all the blocks specified in the input file. + unsigned NextGroupIdx = GroupsOfBlocks.size(); + GroupsOfBlocks.resize(NextGroupIdx + BlocksByName.size()); for (const auto &BInfo : BlocksByName) { Function *F = M.getFunction(BInfo.first); if (!F) report_fatal_error("Invalid function name specified in the input file"); - auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { - return BB.getName().equals(BInfo.second); - }); - if (Res == F->end()) - report_fatal_error("Invalid block name specified in the input file"); - Blocks.push_back(&*Res); + for (const auto &BBInfo : BInfo.second) { + auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { + return BB.getName().equals(BBInfo); + }); + if (Res == F->end()) + report_fatal_error("Invalid block name specified in the input file"); + GroupsOfBlocks[NextGroupIdx].push_back(&*Res); + } + ++NextGroupIdx; } - // Extract basic blocks. - for (BasicBlock *BB : Blocks) { - // Check if the module contains BB. - if (BB->getParent()->getParent() != &M) - report_fatal_error("Invalid basic block"); - LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting " - << BB->getParent()->getName() << ":" << BB->getName() - << "\n"); - SmallVector<BasicBlock *, 2> BlocksToExtractVec; - BlocksToExtractVec.push_back(BB); - if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) - BlocksToExtractVec.push_back(II->getUnwindDest()); - CodeExtractor(BlocksToExtractVec).extractCodeRegion(); - ++NumExtracted; - Changed = true; + // Extract each group of basic blocks. + for (auto &BBs : GroupsOfBlocks) { + SmallVector<BasicBlock *, 32> BlocksToExtractVec; + for (BasicBlock *BB : BBs) { + // Check if the module contains BB. 
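A hypothetical usage sketch of the grouped overload introduced above (BB1-BB3 stand for blocks of functions in M; the declaration is assumed to live in llvm/Transforms/IPO.h, which this file already includes). The same grouping can be requested through the input file parsed by loadFile(), where each line names a function followed by a ';'-separated list of block names, e.g. a line like "foo bb2;bb3".

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO.h"

void extractGroups(llvm::Module &M, llvm::BasicBlock *BB1,
                   llvm::BasicBlock *BB2, llvm::BasicBlock *BB3) {
  llvm::SmallVector<llvm::SmallVector<llvm::BasicBlock *, 16>, 4> Groups;
  Groups.push_back({BB1, BB2}); // Extracted together into one new function.
  Groups.push_back({BB3});      // Extracted into a separate function.

  llvm::legacy::PassManager PM;
  PM.add(llvm::createBlockExtractorPass(Groups, /*EraseFunctions=*/false));
  PM.run(M);
}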
+ if (BB->getParent()->getParent() != &M) + report_fatal_error("Invalid basic block"); + LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting " + << BB->getParent()->getName() << ":" << BB->getName() + << "\n"); + BlocksToExtractVec.push_back(BB); + if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) + BlocksToExtractVec.push_back(II->getUnwindDest()); + ++NumExtracted; + Changed = true; + } + Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(); + if (F) + LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName() + << "' in: " << F->getName() << '\n'); + else + LLVM_DEBUG(dbgs() << "Failed to extract for group '" + << (*BBs.begin())->getName() << "'\n"); } // Erase the functions. diff --git a/lib/Transforms/IPO/CalledValuePropagation.cpp b/lib/Transforms/IPO/CalledValuePropagation.cpp index de62cfc0c1db..20cb3213628e 100644 --- a/lib/Transforms/IPO/CalledValuePropagation.cpp +++ b/lib/Transforms/IPO/CalledValuePropagation.cpp @@ -1,9 +1,8 @@ //===- CalledValuePropagation.cpp - Propagate called values -----*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/IPO/ConstantMerge.cpp b/lib/Transforms/IPO/ConstantMerge.cpp index 81f3634eaf28..ad877ae1786c 100644 --- a/lib/Transforms/IPO/ConstantMerge.cpp +++ b/lib/Transforms/IPO/ConstantMerge.cpp @@ -1,9 +1,8 @@ //===- ConstantMerge.cpp - Merge duplicate global constants ---------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -91,6 +90,16 @@ static unsigned getAlignment(GlobalVariable *GV) { return GV->getParent()->getDataLayout().getPreferredAlignment(GV); } +static bool +isUnmergeableGlobal(GlobalVariable *GV, + const SmallPtrSetImpl<const GlobalValue *> &UsedGlobals) { + // Only process constants with initializers in the default address space. + return !GV->isConstant() || !GV->hasDefinitiveInitializer() || + GV->getType()->getAddressSpace() != 0 || GV->hasSection() || + // Don't touch values marked with attribute(used). + UsedGlobals.count(GV); +} + enum class CanMerge { No, Yes }; static CanMerge makeMergeable(GlobalVariable *Old, GlobalVariable *New) { if (!Old->hasGlobalUnnamedAddr() && !New->hasGlobalUnnamedAddr()) @@ -155,11 +164,7 @@ static bool mergeConstants(Module &M) { continue; } - // Only process constants with initializers in the default address space. - if (!GV->isConstant() || !GV->hasDefinitiveInitializer() || - GV->getType()->getAddressSpace() != 0 || GV->hasSection() || - // Don't touch values marked with attribute(used). 
- UsedGlobals.count(GV)) + if (isUnmergeableGlobal(GV, UsedGlobals)) continue; // This transformation is legal for weak ODR globals in the sense it @@ -197,11 +202,7 @@ static bool mergeConstants(Module &M) { GVI != E; ) { GlobalVariable *GV = &*GVI++; - // Only process constants with initializers in the default address space. - if (!GV->isConstant() || !GV->hasDefinitiveInitializer() || - GV->getType()->getAddressSpace() != 0 || GV->hasSection() || - // Don't touch values marked with attribute(used). - UsedGlobals.count(GV)) + if (isUnmergeableGlobal(GV, UsedGlobals)) continue; // We can only replace constant with local linkage. diff --git a/lib/Transforms/IPO/CrossDSOCFI.cpp b/lib/Transforms/IPO/CrossDSOCFI.cpp index 666f6cc37bfd..e30b33aa4872 100644 --- a/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -1,9 +1,8 @@ //===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -106,10 +105,10 @@ void CrossDSOCFI::buildCFICheck(Module &M) { } LLVMContext &Ctx = M.getContext(); - Constant *C = M.getOrInsertFunction( + FunctionCallee C = M.getOrInsertFunction( "__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx), Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx)); - Function *F = dyn_cast<Function>(C); + Function *F = dyn_cast<Function>(C.getCallee()); // Take over the existing function. The frontend emits a weak stub so that the // linker knows about the symbol; this pass replaces the function body. F->deleteBody(); @@ -133,9 +132,9 @@ void CrossDSOCFI::buildCFICheck(Module &M) { BasicBlock *TrapBB = BasicBlock::Create(Ctx, "fail", F); IRBuilder<> IRBFail(TrapBB); - Constant *CFICheckFailFn = M.getOrInsertFunction( - "__cfi_check_fail", Type::getVoidTy(Ctx), Type::getInt8PtrTy(Ctx), - Type::getInt8PtrTy(Ctx)); + FunctionCallee CFICheckFailFn = + M.getOrInsertFunction("__cfi_check_fail", Type::getVoidTy(Ctx), + Type::getInt8PtrTy(Ctx), Type::getInt8PtrTy(Ctx)); IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr}); IRBFail.CreateBr(ExitBB); diff --git a/lib/Transforms/IPO/DeadArgumentElimination.cpp b/lib/Transforms/IPO/DeadArgumentElimination.cpp index cb30e8f46a54..968a13110b16 100644 --- a/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -1,9 +1,8 @@ //===- DeadArgumentElimination.cpp - Eliminate dead arguments -------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -939,7 +938,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), Args, OpBundles, "", Call->getParent()); } else { - NewCS = CallInst::Create(NF, Args, OpBundles, "", Call); + NewCS = CallInst::Create(NFTy, NF, Args, OpBundles, "", Call); cast<CallInst>(NewCS.getInstruction()) ->setTailCallKind(cast<CallInst>(Call)->getTailCallKind()); } diff --git a/lib/Transforms/IPO/ElimAvailExtern.cpp b/lib/Transforms/IPO/ElimAvailExtern.cpp index d5fef59286dd..fc52db562c62 100644 --- a/lib/Transforms/IPO/ElimAvailExtern.cpp +++ b/lib/Transforms/IPO/ElimAvailExtern.cpp @@ -1,9 +1,8 @@ //===- ElimAvailExtern.cpp - DCE unreachable internal functions -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/IPO/ExtractGV.cpp b/lib/Transforms/IPO/ExtractGV.cpp index a744d7f2d2d9..f77b528fc42d 100644 --- a/lib/Transforms/IPO/ExtractGV.cpp +++ b/lib/Transforms/IPO/ExtractGV.cpp @@ -1,9 +1,8 @@ //===-- ExtractGV.cpp - Global Value extraction pass ----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/lib/Transforms/IPO/ForceFunctionAttrs.cpp index 4dc1529ddbf5..b38cb6d0ed3f 100644 --- a/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -1,9 +1,8 @@ //===- ForceFunctionAttrs.cpp - Force function attrs for debugging --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -58,6 +57,7 @@ static Attribute::AttrKind parseAttrKind(StringRef Kind) { .Case("sanitize_hwaddress", Attribute::SanitizeHWAddress) .Case("sanitize_memory", Attribute::SanitizeMemory) .Case("sanitize_thread", Attribute::SanitizeThread) + .Case("sanitize_memtag", Attribute::SanitizeMemTag) .Case("speculative_load_hardening", Attribute::SpeculativeLoadHardening) .Case("ssp", Attribute::StackProtect) .Case("sspreq", Attribute::StackProtectReq) diff --git a/lib/Transforms/IPO/FunctionAttrs.cpp b/lib/Transforms/IPO/FunctionAttrs.cpp index 4e2a82b56eec..5ccd8bc4b0fb 100644 --- a/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/lib/Transforms/IPO/FunctionAttrs.cpp @@ -1,9 +1,8 @@ //===- FunctionAttrs.cpp - Pass which marks functions attributes ----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -28,6 +27,7 @@ #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" @@ -76,6 +76,7 @@ STATISTIC(NumNoAlias, "Number of function returns marked noalias"); STATISTIC(NumNonNullReturn, "Number of function returns marked nonnull"); STATISTIC(NumNoRecurse, "Number of functions marked as norecurse"); STATISTIC(NumNoUnwind, "Number of functions marked as nounwind"); +STATISTIC(NumNoFree, "Number of functions marked as nofree"); // FIXME: This is disabled by default to avoid exposing security vulnerabilities // in C/C++ code compiled by clang: @@ -89,6 +90,10 @@ static cl::opt<bool> DisableNoUnwindInference( "disable-nounwind-inference", cl::Hidden, cl::desc("Stop inferring nounwind attribute during function-attrs pass")); +static cl::opt<bool> DisableNoFreeInference( + "disable-nofree-inference", cl::Hidden, + cl::desc("Stop inferring nofree attribute during function-attrs pass")); + namespace { using SCCNodeSet = SmallSetVector<Function *, 8>; @@ -256,12 +261,15 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { } } + // If the SCC contains both functions that read and functions that write, then + // we cannot add readonly attributes. + if (ReadsMemory && WritesMemory) + return false; + // Success! Functions in this SCC do not access memory, or only read memory. // Give them the appropriate attribute. bool MadeChange = false; - assert(!(ReadsMemory && WritesMemory) && - "Function marked read-only and write-only"); for (Function *F : SCCNodes) { if (F->doesNotAccessMemory()) // Already perfect! @@ -1228,6 +1236,25 @@ static bool InstrBreaksNonThrowing(Instruction &I, const SCCNodeSet &SCCNodes) { return true; } +/// Helper for NoFree inference predicate InstrBreaksAttribute. 
+static bool InstrBreaksNoFree(Instruction &I, const SCCNodeSet &SCCNodes) { + CallSite CS(&I); + if (!CS) + return false; + + Function *Callee = CS.getCalledFunction(); + if (!Callee) + return true; + + if (Callee->doesNotFreeMemory()) + return false; + + if (SCCNodes.count(Callee) > 0) + return false; + + return true; +} + /// Infer attributes from all functions in the SCC by scanning every /// instruction for compliance to the attribute assumptions. Currently it /// does: @@ -1281,6 +1308,29 @@ static bool inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes) { }, /* RequiresExactDefinition= */ true}); + if (!DisableNoFreeInference) + // Request to infer nofree attribute for all the functions in the SCC if + // every callsite within the SCC does not directly or indirectly free + // memory (except for calls to functions within the SCC). Note that nofree + // attribute suffers from derefinement - results may change depending on + // how functions are optimized. Thus it can be inferred only from exact + // definitions. + AI.registerAttrInference(AttributeInferer::InferenceDescriptor{ + Attribute::NoFree, + // Skip functions known not to free memory. + [](const Function &F) { return F.doesNotFreeMemory(); }, + // Instructions that break non-deallocating assumption. + [SCCNodes](Instruction &I) { + return InstrBreaksNoFree(I, SCCNodes); + }, + [](Function &F) { + LLVM_DEBUG(dbgs() + << "Adding nofree attr to fn " << F.getName() << "\n"); + F.setDoesNotFreeMemory(); + ++NumNoFree; + }, + /* RequiresExactDefinition= */ true}); + // Perform all the requested attribute inference actions. return AI.run(SCCNodes); } @@ -1301,7 +1351,7 @@ static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { return false; Function *F = *SCCNodes.begin(); - if (!F || F->isDeclaration() || F->doesNotRecurse()) + if (!F || !F->hasExactDefinition() || F->doesNotRecurse()) return false; // If all of the calls in F are identifiable and are to norecurse functions, F @@ -1323,7 +1373,8 @@ static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { } template <typename AARGetterT> -static bool deriveAttrsInPostOrder(SCCNodeSet &SCCNodes, AARGetterT &&AARGetter, +static bool deriveAttrsInPostOrder(SCCNodeSet &SCCNodes, + AARGetterT &&AARGetter, bool HasUnknownCall) { bool Changed = false; @@ -1367,8 +1418,7 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, bool HasUnknownCall = false; for (LazyCallGraph::Node &N : C) { Function &F = N.getFunction(); - if (F.hasFnAttribute(Attribute::OptimizeNone) || - F.hasFnAttribute(Attribute::Naked)) { + if (F.hasOptNone() || F.hasFnAttribute(Attribute::Naked)) { // Treat any function we're trying not to optimize as if it were an // indirect call and omit it from the node set used below. HasUnknownCall = true; @@ -1441,8 +1491,7 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { bool ExternalNode = false; for (CallGraphNode *I : SCC) { Function *F = I->getFunction(); - if (!F || F->hasFnAttribute(Attribute::OptimizeNone) || - F->hasFnAttribute(Attribute::Naked)) { + if (!F || F->hasOptNone() || F->hasFnAttribute(Attribute::Naked)) { // External node or function we're trying not to optimize - we both avoid // transform them and avoid leveraging information they provide. 
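A hypothetical source-level illustration of the property the inference descriptor above deduces (the analysis runs on LLVM IR and, per the comment above, only uses exact definitions; all names here are made up). The new 'disable-nofree-inference' option added above turns this inference off.

#include <cstdlib>

// No calls and no deallocation: a nofree candidate.
int identity(int X) { return X; }
// Calls free(), and a caller of it, respectively: neither can be nofree.
void release(int *P) { std::free(P); }
void forward(int *P) { release(P); }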
ExternalNode = true; diff --git a/lib/Transforms/IPO/FunctionImport.cpp b/lib/Transforms/IPO/FunctionImport.cpp index 1223a23512ed..62c7fbd07223 100644 --- a/lib/Transforms/IPO/FunctionImport.cpp +++ b/lib/Transforms/IPO/FunctionImport.cpp @@ -1,9 +1,8 @@ //===- FunctionImport.cpp - ThinLTO Summary-based Function Import ---------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -778,9 +777,7 @@ void llvm::computeDeadSymbols( if (!VI) return; - // We need to make sure all variants of the symbol are scanned, alias can - // make one (but not all) alive. - if (llvm::all_of(VI.getSummaryList(), + if (llvm::any_of(VI.getSummaryList(), [](const std::unique_ptr<llvm::GlobalValueSummary> &S) { return S->isLive(); })) @@ -820,12 +817,18 @@ void llvm::computeDeadSymbols( while (!Worklist.empty()) { auto VI = Worklist.pop_back_val(); for (auto &Summary : VI.getSummaryList()) { - GlobalValueSummary *Base = Summary->getBaseObject(); - // Set base value live in case it is an alias. - Base->setLive(true); - for (auto Ref : Base->refs()) + if (auto *AS = dyn_cast<AliasSummary>(Summary.get())) { + // If this is an alias, visit the aliasee VI to ensure that all copies + // are marked live and it is added to the worklist for further + // processing of its references. + visit(AS->getAliaseeVI()); + continue; + } + + Summary->setLive(true); + for (auto Ref : Summary->refs()) visit(Ref); - if (auto *FS = dyn_cast<FunctionSummary>(Base)) + if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) for (auto Call : FS->calls()) visit(Call.first); } @@ -847,14 +850,16 @@ void llvm::computeDeadSymbolsWithConstProp( bool ImportEnabled) { computeDeadSymbols(Index, GUIDPreservedSymbols, isPrevailing); if (ImportEnabled) { - Index.propagateConstants(GUIDPreservedSymbols); + Index.propagateAttributes(GUIDPreservedSymbols); } else { - // If import is disabled we should drop read-only attribute + // If import is disabled we should drop read/write-only attribute // from all summaries to prevent internalization. for (auto &P : Index) for (auto &S : P.second.SummaryList) - if (auto *GVS = dyn_cast<GlobalVarSummary>(S.get())) + if (auto *GVS = dyn_cast<GlobalVarSummary>(S.get())) { GVS->setReadOnly(false); + GVS->setWriteOnly(false); + } } } @@ -973,12 +978,15 @@ void llvm::thinLTOResolvePrevailingInModule( // changed to enable this for aliases. llvm_unreachable("Expected GV to be converted"); } else { - // If the original symbols has global unnamed addr and linkonce_odr linkage, - // it should be an auto hide symbol. Add hidden visibility to the symbol to - // preserve the property. - if (GV.hasLinkOnceODRLinkage() && GV.hasGlobalUnnamedAddr() && - NewLinkage == GlobalValue::WeakODRLinkage) + // If all copies of the original symbol had global unnamed addr and + // linkonce_odr linkage, it should be an auto hide symbol. In that case + // the thin link would have marked it as CanAutoHide. Add hidden visibility + // to the symbol to preserve the property. 
+ if (NewLinkage == GlobalValue::WeakODRLinkage && + GS->second->canAutoHide()) { + assert(GV.hasLinkOnceODRLinkage() && GV.hasGlobalUnnamedAddr()); GV.setVisibility(GlobalValue::HiddenVisibility); + } LLVM_DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName() << "` from " << GV.getLinkage() << " to " << NewLinkage @@ -1047,9 +1055,10 @@ static Function *replaceAliasWithAliasee(Module *SrcModule, GlobalAlias *GA) { ValueToValueMapTy VMap; Function *NewFn = CloneFunction(Fn, VMap); - // Clone should use the original alias's linkage and name, and we ensure - // all uses of alias instead use the new clone (casted if necessary). + // Clone should use the original alias's linkage, visibility and name, and we + // ensure all uses of alias instead use the new clone (casted if necessary). NewFn->setLinkage(GA->getLinkage()); + NewFn->setVisibility(GA->getVisibility()); GA->replaceAllUsesWith(ConstantExpr::getBitCast(NewFn, GA->getType())); NewFn->takeName(GA); return NewFn; @@ -1057,7 +1066,7 @@ static Function *replaceAliasWithAliasee(Module *SrcModule, GlobalAlias *GA) { // Internalize values that we marked with specific attribute // in processGlobalForThinLTO. -static void internalizeImmutableGVs(Module &M) { +static void internalizeGVsAfterImport(Module &M) { for (auto &GV : M.globals()) // Skip GVs which have been converted to declarations // by dropDeadSymbols. @@ -1190,7 +1199,7 @@ Expected<bool> FunctionImporter::importFunctions( NumImportedModules++; } - internalizeImmutableGVs(DestModule); + internalizeGVsAfterImport(DestModule); NumImportedFunctions += (ImportedCount - ImportedGVCount); NumImportedGlobalVars += ImportedGVCount; diff --git a/lib/Transforms/IPO/GlobalDCE.cpp b/lib/Transforms/IPO/GlobalDCE.cpp index 34de87433367..86b7f3e49ee6 100644 --- a/lib/Transforms/IPO/GlobalDCE.cpp +++ b/lib/Transforms/IPO/GlobalDCE.cpp @@ -1,9 +1,8 @@ //===-- GlobalDCE.cpp - DCE unreachable internal functions ----------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp index 3005aafd06b1..c4fb3ce77f6e 100644 --- a/lib/Transforms/IPO/GlobalOpt.cpp +++ b/lib/Transforms/IPO/GlobalOpt.cpp @@ -1,9 +1,8 @@ //===- GlobalOpt.cpp - Optimize Global Variables --------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -730,7 +729,8 @@ static bool OptimizeAwayTrappingUsesOfValue(Value *V, Constant *NewV) { break; if (Idxs.size() == GEPI->getNumOperands()-1) Changed |= OptimizeAwayTrappingUsesOfValue( - GEPI, ConstantExpr::getGetElementPtr(nullptr, NewV, Idxs)); + GEPI, ConstantExpr::getGetElementPtr(GEPI->getSourceElementType(), + NewV, Idxs)); if (GEPI->use_empty()) { Changed = true; GEPI->eraseFromParent(); @@ -906,9 +906,10 @@ OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, // Replace the cmp X, 0 with a use of the bool value. // Sink the load to where the compare was, if atomic rules allow us to. - Value *LV = new LoadInst(InitBool, InitBool->getName()+".val", false, 0, + Value *LV = new LoadInst(InitBool->getValueType(), InitBool, + InitBool->getName() + ".val", false, 0, LI->getOrdering(), LI->getSyncScopeID(), - LI->isUnordered() ? (Instruction*)ICI : LI); + LI->isUnordered() ? (Instruction *)ICI : LI); InitBoolUsed = true; switch (ICI->getPredicate()) { default: llvm_unreachable("Unknown ICmp Predicate!"); @@ -1041,7 +1042,8 @@ static void ReplaceUsesOfMallocWithGlobal(Instruction *Alloc, } // Insert a load from the global, and use it instead of the malloc. - Value *NL = new LoadInst(GV, GV->getName()+".val", InsertPt); + Value *NL = + new LoadInst(GV->getValueType(), GV, GV->getName() + ".val", InsertPt); U->replaceUsesOfWith(Alloc, NL); } } @@ -1164,10 +1166,10 @@ static Value *GetHeapSROAValue(Value *V, unsigned FieldNo, if (LoadInst *LI = dyn_cast<LoadInst>(V)) { // This is a scalarized version of the load from the global. Just create // a new Load of the scalarized global. - Result = new LoadInst(GetHeapSROAValue(LI->getOperand(0), FieldNo, - InsertedScalarizedValues, - PHIsToRewrite), - LI->getName()+".f"+Twine(FieldNo), LI); + Value *V = GetHeapSROAValue(LI->getOperand(0), FieldNo, + InsertedScalarizedValues, PHIsToRewrite); + Result = new LoadInst(V->getType()->getPointerElementType(), V, + LI->getName() + ".f" + Twine(FieldNo), LI); } else { PHINode *PN = cast<PHINode>(V); // PN's type is pointer to struct. Make a new PHI of pointer to struct @@ -1357,7 +1359,9 @@ static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, // Within the NullPtrBlock, we need to emit a comparison and branch for each // pointer, because some may be null while others are not. for (unsigned i = 0, e = FieldGlobals.size(); i != e; ++i) { - Value *GVVal = new LoadInst(FieldGlobals[i], "tmp", NullPtrBlock); + Value *GVVal = + new LoadInst(cast<GlobalVariable>(FieldGlobals[i])->getValueType(), + FieldGlobals[i], "tmp", NullPtrBlock); Value *Cmp = new ICmpInst(*NullPtrBlock, ICmpInst::ICMP_NE, GVVal, Constant::getNullValue(GVVal->getType())); BasicBlock *FreeBlock = BasicBlock::Create(Cmp->getContext(), "free_it", @@ -1650,6 +1654,9 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { for(auto *GVe : GVs){ DIGlobalVariable *DGV = GVe->getVariable(); DIExpression *E = GVe->getExpression(); + const DataLayout &DL = GV->getParent()->getDataLayout(); + unsigned SizeInOctets = + DL.getTypeAllocSizeInBits(NewGV->getType()->getElementType()) / 8; // It is expected that the address of global optimized variable is on // top of the stack. 
After optimization, value of that variable will @@ -1660,10 +1667,12 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { // DW_OP_deref DW_OP_constu <ValMinus> // DW_OP_mul DW_OP_constu <ValInit> DW_OP_plus DW_OP_stack_value SmallVector<uint64_t, 12> Ops = { - dwarf::DW_OP_deref, dwarf::DW_OP_constu, ValMinus, - dwarf::DW_OP_mul, dwarf::DW_OP_constu, ValInit, + dwarf::DW_OP_deref_size, SizeInOctets, + dwarf::DW_OP_constu, ValMinus, + dwarf::DW_OP_mul, dwarf::DW_OP_constu, ValInit, dwarf::DW_OP_plus}; - E = DIExpression::prependOpcodes(E, Ops, DIExpression::WithStackValue); + bool WithStackValue = true; + E = DIExpression::prependOpcodes(E, Ops, WithStackValue); DIGlobalVariableExpression *DGVE = DIGlobalVariableExpression::get(NewGV->getContext(), DGV, E); NewGV->addDebugInfo(DGVE); @@ -1701,7 +1710,8 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { if (LoadInst *LI = dyn_cast<LoadInst>(StoredVal)) { assert(LI->getOperand(0) == GV && "Not a copy!"); // Insert a new load, to preserve the saved value. - StoreVal = new LoadInst(NewGV, LI->getName()+".b", false, 0, + StoreVal = new LoadInst(NewGV->getValueType(), NewGV, + LI->getName() + ".b", false, 0, LI->getOrdering(), LI->getSyncScopeID(), LI); } else { assert((isa<CastInst>(StoredVal) || isa<SelectInst>(StoredVal)) && @@ -1717,8 +1727,9 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { } else { // Change the load into a load of bool then a select. LoadInst *LI = cast<LoadInst>(UI); - LoadInst *NLI = new LoadInst(NewGV, LI->getName()+".b", false, 0, - LI->getOrdering(), LI->getSyncScopeID(), LI); + LoadInst *NLI = + new LoadInst(NewGV->getValueType(), NewGV, LI->getName() + ".b", + false, 0, LI->getOrdering(), LI->getSyncScopeID(), LI); Instruction *NSI; if (IsOneZero) NSI = new ZExtInst(NLI, LI->getType(), "", LI); @@ -1970,7 +1981,12 @@ static bool processInternalGlobal( } if (GS.StoredType <= GlobalStatus::InitializerStored) { LLVM_DEBUG(dbgs() << "MARKING CONSTANT: " << *GV << "\n"); - GV->setConstant(true); + + // Don't actually mark a global constant if it's atomic because atomic loads + // are implemented by a trivial cmpxchg in some edge-cases and that usually + // requires write access to the variable even if it's not actually changed. + if (GS.Ordering == AtomicOrdering::NotAtomic) + GV->setConstant(true); // Clean up any obviously simplifiable users now. CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, TLI); @@ -2084,21 +2100,21 @@ static void ChangeCalleesToFastCall(Function *F) { } } -static AttributeList StripNest(LLVMContext &C, AttributeList Attrs) { - // There can be at most one attribute set with a nest attribute. 
- unsigned NestIndex; - if (Attrs.hasAttrSomewhere(Attribute::Nest, &NestIndex)) - return Attrs.removeAttribute(C, NestIndex, Attribute::Nest); +static AttributeList StripAttr(LLVMContext &C, AttributeList Attrs, + Attribute::AttrKind A) { + unsigned AttrIndex; + if (Attrs.hasAttrSomewhere(A, &AttrIndex)) + return Attrs.removeAttribute(C, AttrIndex, A); return Attrs; } -static void RemoveNestAttribute(Function *F) { - F->setAttributes(StripNest(F->getContext(), F->getAttributes())); +static void RemoveAttribute(Function *F, Attribute::AttrKind A) { + F->setAttributes(StripAttr(F->getContext(), F->getAttributes(), A)); for (User *U : F->users()) { if (isa<BlockAddress>(U)) continue; CallSite CS(cast<Instruction>(U)); - CS.setAttributes(StripNest(F->getContext(), CS.getAttributes())); + CS.setAttributes(StripAttr(F->getContext(), CS.getAttributes(), A)); } } @@ -2113,13 +2129,6 @@ static bool hasChangeableCC(Function *F) { if (CC != CallingConv::C && CC != CallingConv::X86_ThisCall) return false; - // Don't break the invariant that the inalloca parameter is the only parameter - // passed in memory. - // FIXME: GlobalOpt should remove inalloca when possible and hoist the dynamic - // alloca it uses to the entry block if possible. - if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca)) - return false; - // FIXME: Change CC for the whole chain of musttail calls when possible. // // Can't change CC of the function that either has musttail calls, or is a @@ -2281,6 +2290,17 @@ OptimizeFunctions(Module &M, TargetLibraryInfo *TLI, if (!F->hasLocalLinkage()) continue; + // If we have an inalloca parameter that we can safely remove the + // inalloca attribute from, do so. This unlocks optimizations that + // wouldn't be safe in the presence of inalloca. + // FIXME: We should also hoist alloca affected by this to the entry + // block if possible. + if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca) && + !F->hasAddressTaken()) { + RemoveAttribute(F, Attribute::InAlloca); + Changed = true; + } + if (hasChangeableCC(F) && !F->isVarArg() && !F->hasAddressTaken()) { NumInternalFunc++; TargetTransformInfo &TTI = GetTTI(*F); @@ -2289,8 +2309,8 @@ OptimizeFunctions(Module &M, TargetLibraryInfo *TLI, // cold at all call sites and the callers contain no other non coldcc // calls. if (EnableColdCCStressTest || - (isValidCandidateForColdCC(*F, GetBFI, AllCallsCold) && - TTI.useColdCCForColdCall(*F))) { + (TTI.useColdCCForColdCall(*F) && + isValidCandidateForColdCC(*F, GetBFI, AllCallsCold))) { F->setCallingConv(CallingConv::Cold); changeCallSitesToColdCC(F); Changed = true; @@ -2313,7 +2333,7 @@ OptimizeFunctions(Module &M, TargetLibraryInfo *TLI, !F->hasAddressTaken()) { // The function is not used by a trampoline intrinsic, so it is safe // to remove the 'nest' attribute. - RemoveNestAttribute(F); + RemoveAttribute(F, Attribute::Nest); ++NumNestRemoved; Changed = true; } @@ -2808,46 +2828,20 @@ static Function *FindCXAAtExit(Module &M, TargetLibraryInfo *TLI) { /// Returns whether the given function is an empty C++ destructor and can /// therefore be eliminated. /// Note that we assume that other optimization passes have already simplified -/// the code so we only look for a function with a single basic block, where -/// the only allowed instructions are 'ret', 'call' to an empty C++ dtor and -/// other side-effect free instructions. -static bool cxxDtorIsEmpty(const Function &Fn, - SmallPtrSet<const Function *, 8> &CalledFunctions) { +/// the code so we simply check for 'ret'. 
+static bool cxxDtorIsEmpty(const Function &Fn) { // FIXME: We could eliminate C++ destructors if they're readonly/readnone and // nounwind, but that doesn't seem worth doing. if (Fn.isDeclaration()) return false; - if (++Fn.begin() != Fn.end()) - return false; - - const BasicBlock &EntryBlock = Fn.getEntryBlock(); - for (BasicBlock::const_iterator I = EntryBlock.begin(), E = EntryBlock.end(); - I != E; ++I) { - if (const CallInst *CI = dyn_cast<CallInst>(I)) { - // Ignore debug intrinsics. - if (isa<DbgInfoIntrinsic>(CI)) - continue; - - const Function *CalledFn = CI->getCalledFunction(); - - if (!CalledFn) - return false; - - SmallPtrSet<const Function *, 8> NewCalledFunctions(CalledFunctions); - - // Don't treat recursive functions as empty. - if (!NewCalledFunctions.insert(CalledFn).second) - return false; - - if (!cxxDtorIsEmpty(*CalledFn, NewCalledFunctions)) - return false; - } else if (isa<ReturnInst>(*I)) - return true; // We're done. - else if (I->mayHaveSideEffects()) - return false; // Destructor with side effects, bail. + for (auto &I : Fn.getEntryBlock()) { + if (isa<DbgInfoIntrinsic>(I)) + continue; + if (isa<ReturnInst>(I)) + return true; + break; } - return false; } @@ -2879,11 +2873,7 @@ static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { Function *DtorFn = dyn_cast<Function>(CI->getArgOperand(0)->stripPointerCasts()); - if (!DtorFn) - continue; - - SmallPtrSet<const Function *, 8> CalledFunctions; - if (!cxxDtorIsEmpty(*DtorFn, CalledFunctions)) + if (!DtorFn || !cxxDtorIsEmpty(*DtorFn)) continue; // Just remove the call. diff --git a/lib/Transforms/IPO/GlobalSplit.cpp b/lib/Transforms/IPO/GlobalSplit.cpp index 792f4b3052a3..060043a40b89 100644 --- a/lib/Transforms/IPO/GlobalSplit.cpp +++ b/lib/Transforms/IPO/GlobalSplit.cpp @@ -1,9 +1,8 @@ //===- GlobalSplit.cpp - global variable splitter -------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/IPO/HotColdSplitting.cpp b/lib/Transforms/IPO/HotColdSplitting.cpp index 924a7d5fbd9c..ab1a9a79cad6 100644 --- a/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/lib/Transforms/IPO/HotColdSplitting.cpp @@ -1,16 +1,28 @@ //===- HotColdSplitting.cpp -- Outline Cold Regions -------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// Outline cold regions to a separate function. -// TODO: Update BFI and BPI -// TODO: Add all the outlined functions to a separate section. -// +/// +/// \file +/// The goal of hot/cold splitting is to improve the memory locality of code. +/// The splitting pass does this by identifying cold blocks and moving them into +/// separate functions. 
+/// +/// When the splitting pass finds a cold block (referred to as "the sink"), it +/// grows a maximal cold region around that block. The maximal region contains +/// all blocks (post-)dominated by the sink [*]. In theory, these blocks are as +/// cold as the sink. Once a region is found, it's split out of the original +/// function provided it's profitable to do so. +/// +/// [*] In practice, there is some added complexity because some blocks are not +/// safe to extract. +/// +/// TODO: Use the PM to get domtrees, and preserve BFI/BPI. +/// TODO: Reorder outlined functions. +/// //===----------------------------------------------------------------------===// #include "llvm/ADT/PostOrderIterator.h" @@ -53,7 +65,6 @@ #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/CodeExtractor.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> @@ -69,16 +80,12 @@ static cl::opt<bool> EnableStaticAnalyis("hot-cold-static-analysis", cl::init(true), cl::Hidden); static cl::opt<int> - MinOutliningThreshold("min-outlining-thresh", cl::init(3), cl::Hidden, - cl::desc("Code size threshold for outlining within a " - "single BB (as a multiple of TCC_Basic)")); + SplittingThreshold("hotcoldsplit-threshold", cl::init(2), cl::Hidden, + cl::desc("Base penalty for splitting cold code (as a " + "multiple of TCC_Basic)")); namespace { -struct PostDomTree : PostDomTreeBase<BasicBlock> { - PostDomTree(Function &F) { recalculate(F); } -}; - /// A sequence of basic blocks. /// /// A 0-sized SmallVector is slightly cheaper to move than a std::vector. @@ -101,13 +108,14 @@ bool blockEndsInUnreachable(const BasicBlock &BB) { bool unlikelyExecuted(BasicBlock &BB) { // Exception handling blocks are unlikely executed. - if (BB.isEHPad()) + if (BB.isEHPad() || isa<ResumeInst>(BB.getTerminator())) return true; - // The block is cold if it calls/invokes a cold function. + // The block is cold if it calls/invokes a cold function. However, do not + // mark sanitizer traps as cold. for (Instruction &I : BB) if (auto CS = CallSite(&I)) - if (CS.hasFnAttr(Attribute::Cold)) + if (CS.hasFnAttr(Attribute::Cold) && !CS->getMetadata("nosanitize")) return true; // The block is cold if it has an unreachable terminator, unless it's @@ -125,38 +133,39 @@ bool unlikelyExecuted(BasicBlock &BB) { /// Check whether it's safe to outline \p BB. static bool mayExtractBlock(const BasicBlock &BB) { - return !BB.hasAddressTaken() && !BB.isEHPad(); -} - -/// Check whether \p Region is profitable to outline. -static bool isProfitableToOutline(const BlockSequence &Region, - TargetTransformInfo &TTI) { - if (Region.size() > 1) - return true; - - int Cost = 0; - const BasicBlock &BB = *Region[0]; - for (const Instruction &I : BB) { - if (isa<DbgInfoIntrinsic>(&I) || &I == BB.getTerminator()) - continue; - - Cost += TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); - - if (Cost >= (MinOutliningThreshold * TargetTransformInfo::TCC_Basic)) - return true; - } - return false; + // EH pads are unsafe to outline because doing so breaks EH type tables. It + // follows that invoke instructions cannot be extracted, because CodeExtractor + // requires unwind destinations to be within the extraction region. + // + // Resumes that are not reachable from a cleanup landing pad are considered to + // be unreachable. It’s not safe to split them out either. 
+ auto Term = BB.getTerminator(); + return !BB.hasAddressTaken() && !BB.isEHPad() && !isa<InvokeInst>(Term) && + !isa<ResumeInst>(Term); } -/// Mark \p F cold. Return true if it's changed. -static bool markEntireFunctionCold(Function &F) { - assert(!F.hasFnAttribute(Attribute::OptimizeNone) && "Can't mark this cold"); +/// Mark \p F cold. Based on this assumption, also optimize it for minimum size. +/// If \p UpdateEntryCount is true (set when this is a new split function and +/// module has profile data), set entry count to 0 to ensure treated as cold. +/// Return true if the function is changed. +static bool markFunctionCold(Function &F, bool UpdateEntryCount = false) { + assert(!F.hasOptNone() && "Can't mark this cold"); bool Changed = false; + if (!F.hasFnAttribute(Attribute::Cold)) { + F.addFnAttr(Attribute::Cold); + Changed = true; + } if (!F.hasFnAttribute(Attribute::MinSize)) { F.addFnAttr(Attribute::MinSize); Changed = true; } - // TODO: Move this function into a cold section. + if (UpdateEntryCount) { + // Set the entry count to 0 to ensure it is placed in the unlikely text + // section when function sections are enabled. + F.setEntryCount(0); + Changed = true; + } + return Changed; } @@ -165,24 +174,24 @@ public: HotColdSplitting(ProfileSummaryInfo *ProfSI, function_ref<BlockFrequencyInfo *(Function &)> GBFI, function_ref<TargetTransformInfo &(Function &)> GTTI, - std::function<OptimizationRemarkEmitter &(Function &)> *GORE) - : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE) {} + std::function<OptimizationRemarkEmitter &(Function &)> *GORE, + function_ref<AssumptionCache *(Function &)> LAC) + : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE), LookupAC(LAC) {} bool run(Module &M); private: + bool isFunctionCold(const Function &F) const; bool shouldOutlineFrom(const Function &F) const; - bool outlineColdRegions(Function &F, ProfileSummaryInfo &PSI, - BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, - DominatorTree &DT, PostDomTree &PDT, - OptimizationRemarkEmitter &ORE); + bool outlineColdRegions(Function &F, bool HasProfileSummary); Function *extractColdRegion(const BlockSequence &Region, DominatorTree &DT, BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, - OptimizationRemarkEmitter &ORE, unsigned Count); - SmallPtrSet<const Function *, 2> OutlinedFunctions; + OptimizationRemarkEmitter &ORE, + AssumptionCache *AC, unsigned Count); ProfileSummaryInfo *PSI; function_ref<BlockFrequencyInfo *(Function &)> GetBFI; function_ref<TargetTransformInfo &(Function &)> GetTTI; std::function<OptimizationRemarkEmitter &(Function &)> *GetORE; + function_ref<AssumptionCache *(Function &)> LookupAC; }; class HotColdSplittingLegacyPass : public ModulePass { @@ -193,10 +202,10 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<BlockFrequencyInfoWrapperPass>(); AU.addRequired<ProfileSummaryInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addUsedIfAvailable<AssumptionCacheTracker>(); } bool runOnModule(Module &M) override; @@ -204,59 +213,141 @@ public: } // end anonymous namespace -// Returns false if the function should not be considered for hot-cold split -// optimization. -bool HotColdSplitting::shouldOutlineFrom(const Function &F) const { - // Do not try to outline again from an already outlined cold function. - if (OutlinedFunctions.count(&F)) - return false; +/// Check whether \p F is inherently cold. 
+bool HotColdSplitting::isFunctionCold(const Function &F) const { + if (F.hasFnAttribute(Attribute::Cold)) + return true; - if (F.size() <= 2) - return false; + if (F.getCallingConv() == CallingConv::Cold) + return true; - // TODO: Consider only skipping functions marked `optnone` or `cold`. + if (PSI->isFunctionEntryCold(&F)) + return true; - if (F.hasAddressTaken()) - return false; + return false; +} +// Returns false if the function should not be considered for hot-cold split +// optimization. +bool HotColdSplitting::shouldOutlineFrom(const Function &F) const { if (F.hasFnAttribute(Attribute::AlwaysInline)) return false; if (F.hasFnAttribute(Attribute::NoInline)) return false; - if (F.getCallingConv() == CallingConv::Cold) + if (F.hasFnAttribute(Attribute::SanitizeAddress) || + F.hasFnAttribute(Attribute::SanitizeHWAddress) || + F.hasFnAttribute(Attribute::SanitizeThread) || + F.hasFnAttribute(Attribute::SanitizeMemory)) return false; - if (PSI->isFunctionEntryCold(&F)) - return false; return true; } +/// Get the benefit score of outlining \p Region. +static int getOutliningBenefit(ArrayRef<BasicBlock *> Region, + TargetTransformInfo &TTI) { + // Sum up the code size costs of non-terminator instructions. Tight coupling + // with \ref getOutliningPenalty is needed to model the costs of terminators. + int Benefit = 0; + for (BasicBlock *BB : Region) + for (Instruction &I : BB->instructionsWithoutDebug()) + if (&I != BB->getTerminator()) + Benefit += + TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); + + return Benefit; +} + +/// Get the penalty score for outlining \p Region. +static int getOutliningPenalty(ArrayRef<BasicBlock *> Region, + unsigned NumInputs, unsigned NumOutputs) { + int Penalty = SplittingThreshold; + LLVM_DEBUG(dbgs() << "Applying penalty for splitting: " << Penalty << "\n"); + + // If the splitting threshold is set at or below zero, skip the usual + // profitability check. + if (SplittingThreshold <= 0) + return Penalty; + + // The typical code size cost for materializing an argument for the outlined + // call. + LLVM_DEBUG(dbgs() << "Applying penalty for: " << NumInputs << " inputs\n"); + const int CostForArgMaterialization = TargetTransformInfo::TCC_Basic; + Penalty += CostForArgMaterialization * NumInputs; + + // The typical code size cost for an output alloca, its associated store, and + // its associated reload. + LLVM_DEBUG(dbgs() << "Applying penalty for: " << NumOutputs << " outputs\n"); + const int CostForRegionOutput = 3 * TargetTransformInfo::TCC_Basic; + Penalty += CostForRegionOutput * NumOutputs; + + // Find the number of distinct exit blocks for the region. Use a conservative + // check to determine whether control returns from the region. + bool NoBlocksReturn = true; + SmallPtrSet<BasicBlock *, 2> SuccsOutsideRegion; + for (BasicBlock *BB : Region) { + // If a block has no successors, only assume it does not return if it's + // unreachable. + if (succ_empty(BB)) { + NoBlocksReturn &= isa<UnreachableInst>(BB->getTerminator()); + continue; + } + + for (BasicBlock *SuccBB : successors(BB)) { + if (find(Region, SuccBB) == Region.end()) { + NoBlocksReturn = false; + SuccsOutsideRegion.insert(SuccBB); + } + } + } + + // Apply a `noreturn` bonus. + if (NoBlocksReturn) { + LLVM_DEBUG(dbgs() << "Applying bonus for: " << Region.size() + << " non-returning terminators\n"); + Penalty -= Region.size(); + } + + // Apply a penalty for having more than one successor outside of the region. + // This penalty accounts for the switch needed in the caller. 
+ if (!SuccsOutsideRegion.empty()) { + LLVM_DEBUG(dbgs() << "Applying penalty for: " << SuccsOutsideRegion.size() + << " non-region successors\n"); + Penalty += (SuccsOutsideRegion.size() - 1) * TargetTransformInfo::TCC_Basic; + } + + return Penalty; +} + Function *HotColdSplitting::extractColdRegion(const BlockSequence &Region, DominatorTree &DT, BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, OptimizationRemarkEmitter &ORE, + AssumptionCache *AC, unsigned Count) { assert(!Region.empty()); // TODO: Pass BFI and BPI to update profile information. CodeExtractor CE(Region, &DT, /* AggregateArgs */ false, /* BFI */ nullptr, - /* BPI */ nullptr, /* AllowVarArgs */ false, + /* BPI */ nullptr, AC, /* AllowVarArgs */ false, /* AllowAlloca */ false, /* Suffix */ "cold." + std::to_string(Count)); + // Perform a simple cost/benefit analysis to decide whether or not to permit + // splitting. SetVector<Value *> Inputs, Outputs, Sinks; CE.findInputsOutputs(Inputs, Outputs, Sinks); - - // Do not extract regions that have live exit variables. - if (Outputs.size() > 0) { - LLVM_DEBUG(llvm::dbgs() << "Not outlining; live outputs\n"); + int OutliningBenefit = getOutliningBenefit(Region, TTI); + int OutliningPenalty = + getOutliningPenalty(Region, Inputs.size(), Outputs.size()); + LLVM_DEBUG(dbgs() << "Split profitability: benefit = " << OutliningBenefit + << ", penalty = " << OutliningPenalty << "\n"); + if (OutliningBenefit <= OutliningPenalty) return nullptr; - } - // TODO: Run MergeBasicBlockIntoOnlyPred on the outlined function. Function *OrigF = Region[0]->getParent(); if (Function *OutF = CE.extractCodeRegion()) { User *U = *OutF->user_begin(); @@ -269,9 +360,7 @@ Function *HotColdSplitting::extractColdRegion(const BlockSequence &Region, } CI->setIsNoInline(); - // Try to make the outlined code as small as possible on the assumption - // that it's cold. - markEntireFunctionCold(*OutF); + markFunctionCold(*OutF, BFI != nullptr); LLVM_DEBUG(llvm::dbgs() << "Outlined Region: " << *OutF); ORE.emit([&]() { @@ -298,6 +387,8 @@ using BlockTy = std::pair<BasicBlock *, unsigned>; namespace { /// A maximal outlining region. This contains all blocks post-dominated by a /// sink block, the sink block itself, and all blocks dominated by the sink. +/// If sink-predecessors and sink-successors cannot be extracted in one region, +/// the static constructor returns a list of suitable extraction regions. class OutliningRegion { /// A list of (block, score) pairs. A block's score is non-zero iff it's a /// viable sub-region entry point. Blocks with higher scores are better entry @@ -312,12 +403,9 @@ class OutliningRegion { /// Whether the entire function is cold. bool EntireFunctionCold = false; - /// Whether or not \p BB could be the entry point of an extracted region. - static bool isViableEntryPoint(BasicBlock &BB) { return !BB.isEHPad(); } - /// If \p BB is a viable entry point, return \p Score. Return 0 otherwise. static unsigned getEntryPointScore(BasicBlock &BB, unsigned Score) { - return isViableEntryPoint(BB) ? Score : 0; + return mayExtractBlock(BB) ? 
Score : 0; } /// These scores should be lower than the score for predecessor blocks, @@ -333,21 +421,23 @@ public: OutliningRegion(OutliningRegion &&) = default; OutliningRegion &operator=(OutliningRegion &&) = default; - static OutliningRegion create(BasicBlock &SinkBB, const DominatorTree &DT, - const PostDomTree &PDT) { - OutliningRegion ColdRegion; - + static std::vector<OutliningRegion> create(BasicBlock &SinkBB, + const DominatorTree &DT, + const PostDominatorTree &PDT) { + std::vector<OutliningRegion> Regions; SmallPtrSet<BasicBlock *, 4> RegionBlocks; + Regions.emplace_back(); + OutliningRegion *ColdRegion = &Regions.back(); + auto addBlockToRegion = [&](BasicBlock *BB, unsigned Score) { RegionBlocks.insert(BB); - ColdRegion.Blocks.emplace_back(BB, Score); - assert(RegionBlocks.size() == ColdRegion.Blocks.size() && "Duplicate BB"); + ColdRegion->Blocks.emplace_back(BB, Score); }; // The ancestor farthest-away from SinkBB, and also post-dominated by it. unsigned SinkScore = getEntryPointScore(SinkBB, ScoreForSinkBlock); - ColdRegion.SuggestedEntryPoint = (SinkScore > 0) ? &SinkBB : nullptr; + ColdRegion->SuggestedEntryPoint = (SinkScore > 0) ? &SinkBB : nullptr; unsigned BestScore = SinkScore; // Visit SinkBB's ancestors using inverse DFS. @@ -360,8 +450,8 @@ public: // If the predecessor is cold and has no predecessors, the entire // function must be cold. if (SinkPostDom && pred_empty(&PredBB)) { - ColdRegion.EntireFunctionCold = true; - return ColdRegion; + ColdRegion->EntireFunctionCold = true; + return Regions; } // If SinkBB does not post-dominate a predecessor, do not mark the @@ -376,7 +466,7 @@ public: // considered as entry points before the sink block. unsigned PredScore = getEntryPointScore(PredBB, PredIt.getPathLength()); if (PredScore > BestScore) { - ColdRegion.SuggestedEntryPoint = &PredBB; + ColdRegion->SuggestedEntryPoint = &PredBB; BestScore = PredScore; } @@ -384,9 +474,19 @@ public: ++PredIt; } - // Add SinkBB to the cold region. It's considered as an entry point before - // any sink-successor blocks. - addBlockToRegion(&SinkBB, SinkScore); + // If the sink can be added to the cold region, do so. It's considered as + // an entry point before any sink-successor blocks. + // + // Otherwise, split cold sink-successor blocks using a separate region. + // This satisfies the requirement that all extraction blocks other than the + // first have predecessors within the extraction region. + if (mayExtractBlock(SinkBB)) { + addBlockToRegion(&SinkBB, SinkScore); + } else { + Regions.emplace_back(); + ColdRegion = &Regions.back(); + BestScore = 0; + } // Find all successors of SinkBB dominated by SinkBB using DFS. auto SuccIt = ++df_begin(&SinkBB); @@ -407,7 +507,7 @@ public: unsigned SuccScore = getEntryPointScore(SuccBB, ScoreForSuccBlock); if (SuccScore > BestScore) { - ColdRegion.SuggestedEntryPoint = &SuccBB; + ColdRegion->SuggestedEntryPoint = &SuccBB; BestScore = SuccScore; } @@ -415,7 +515,7 @@ public: ++SuccIt; } - return ColdRegion; + return Regions; } /// Whether this region has nothing to extract. @@ -461,11 +561,7 @@ public: }; } // namespace -bool HotColdSplitting::outlineColdRegions(Function &F, ProfileSummaryInfo &PSI, - BlockFrequencyInfo *BFI, - TargetTransformInfo &TTI, - DominatorTree &DT, PostDomTree &PDT, - OptimizationRemarkEmitter &ORE) { +bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { bool Changed = false; // The set of cold blocks. 
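The cost model introduced in the hunks above replaces the old per-block MinOutliningThreshold check: a region is split only when the summed code-size cost of its non-terminator instructions (the benefit) exceeds a penalty built from the base splitting threshold, one TCC_Basic per input to materialize arguments, three TCC_Basic per output for the alloca/store/reload triple, a bonus for regions that never return, and a switch cost when the region has more than one successor outside it. A minimal standalone sketch of that arithmetic, with hypothetical names and TCC_Basic modeled as a unit cost of 1 (the real pass queries TargetTransformInfo for per-instruction costs):

// Standalone sketch (hypothetical names) of the hot/cold splitting cost model.
#include <cstdio>

struct RegionStats {
  int NonTerminatorCost;   // summed code-size cost of non-terminator instructions
  int NumInputs;           // values live into the region
  int NumOutputs;          // values live out of the region
  int NumBlocks;           // blocks in the region
  int NumExitSuccessors;   // distinct successors outside the region
  bool NoBlockReturns;     // region never returns control to the caller
};

static const int TCCBasic = 1;  // stand-in for TargetTransformInfo::TCC_Basic

static int outliningPenalty(const RegionStats &R, int SplittingThreshold) {
  int Penalty = SplittingThreshold;
  if (SplittingThreshold <= 0)             // non-positive threshold skips the check
    return Penalty;
  Penalty += TCCBasic * R.NumInputs;       // materialize each argument
  Penalty += 3 * TCCBasic * R.NumOutputs;  // alloca + store + reload per output
  if (R.NoBlockReturns)
    Penalty -= R.NumBlocks;                // bonus for non-returning regions
  if (R.NumExitSuccessors > 0)             // switch in the caller for extra exits
    Penalty += (R.NumExitSuccessors - 1) * TCCBasic;
  return Penalty;
}

int main() {
  RegionStats R = {12, 2, 1, 3, 1, false};
  int Benefit = R.NonTerminatorCost;       // getOutliningBenefit analogue
  int Penalty = outliningPenalty(R, 2);    // default -hotcoldsplit-threshold is 2
  std::printf("benefit=%d penalty=%d -> %s\n", Benefit, Penalty,
              Benefit > Penalty ? "split" : "keep");
}

With these example numbers the penalty is 2 + 2 + 3 + 0 = 7 against a benefit of 12, so the region would be extracted; extractColdRegion gives up whenever the benefit is less than or equal to the penalty.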
@@ -479,17 +575,28 @@ bool HotColdSplitting::outlineColdRegions(Function &F, ProfileSummaryInfo &PSI, // the first region to contain a block. ReversePostOrderTraversal<Function *> RPOT(&F); + // Calculate domtrees lazily. This reduces compile-time significantly. + std::unique_ptr<DominatorTree> DT; + std::unique_ptr<PostDominatorTree> PDT; + + // Calculate BFI lazily (it's only used to query ProfileSummaryInfo). This + // reduces compile-time significantly. TODO: When we *do* use BFI, we should + // be able to salvage its domtrees instead of recomputing them. + BlockFrequencyInfo *BFI = nullptr; + if (HasProfileSummary) + BFI = GetBFI(F); + + TargetTransformInfo &TTI = GetTTI(F); + OptimizationRemarkEmitter &ORE = (*GetORE)(F); + AssumptionCache *AC = LookupAC(F); + // Find all cold regions. for (BasicBlock *BB : RPOT) { - // Skip blocks which can't be outlined. - if (!mayExtractBlock(*BB)) - continue; - // This block is already part of some outlining region. if (ColdBlocks.count(BB)) continue; - bool Cold = PSI.isColdBlock(BB, BFI) || + bool Cold = (BFI && PSI->isColdBlock(BB, BFI)) || (EnableStaticAnalyis && unlikelyExecuted(*BB)); if (!Cold) continue; @@ -499,28 +606,35 @@ bool HotColdSplitting::outlineColdRegions(Function &F, ProfileSummaryInfo &PSI, BB->dump(); }); - auto Region = OutliningRegion::create(*BB, DT, PDT); - if (Region.empty()) - continue; + if (!DT) + DT = make_unique<DominatorTree>(F); + if (!PDT) + PDT = make_unique<PostDominatorTree>(F); - if (Region.isEntireFunctionCold()) { - LLVM_DEBUG(dbgs() << "Entire function is cold\n"); - return markEntireFunctionCold(F); - } + auto Regions = OutliningRegion::create(*BB, *DT, *PDT); + for (OutliningRegion &Region : Regions) { + if (Region.empty()) + continue; - // If this outlining region intersects with another, drop the new region. - // - // TODO: It's theoretically possible to outline more by only keeping the - // largest region which contains a block, but the extra bookkeeping to do - // this is tricky/expensive. - bool RegionsOverlap = any_of(Region.blocks(), [&](const BlockTy &Block) { - return !ColdBlocks.insert(Block.first).second; - }); - if (RegionsOverlap) - continue; + if (Region.isEntireFunctionCold()) { + LLVM_DEBUG(dbgs() << "Entire function is cold\n"); + return markFunctionCold(F); + } + + // If this outlining region intersects with another, drop the new region. + // + // TODO: It's theoretically possible to outline more by only keeping the + // largest region which contains a block, but the extra bookkeeping to do + // this is tricky/expensive. + bool RegionsOverlap = any_of(Region.blocks(), [&](const BlockTy &Block) { + return !ColdBlocks.insert(Block.first).second; + }); + if (RegionsOverlap) + continue; - OutliningWorklist.emplace_back(std::move(Region)); - ++NumColdRegionsFound; + OutliningWorklist.emplace_back(std::move(Region)); + ++NumColdRegionsFound; + } } // Outline single-entry cold regions, splitting up larger regions as needed. 
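The region-collection loop that ends above keeps candidate regions disjoint with a small idiom: each block of a new region is inserted into the shared ColdBlocks set, and the region is dropped as soon as an insertion fails because an earlier region already claimed that block. A standalone sketch of that check, using integer IDs in place of BasicBlock pointers (the names here are illustrative, not part of the pass):

// Sketch of the overlap test used while collecting cold regions: a candidate
// region is rejected if any of its blocks was already claimed. Like the
// any_of-over-insert in the pass, this short-circuits on the first
// already-claimed block, so earlier blocks of a rejected region stay marked.
#include <set>
#include <vector>

using BlockID = int;  // stands in for BasicBlock *

static bool claimRegion(const std::vector<BlockID> &Region,
                        std::set<BlockID> &ColdBlocks) {
  for (BlockID B : Region)
    if (!ColdBlocks.insert(B).second)
      return false;  // overlap with a previously visited region
  return true;
}

int main() {
  std::set<BlockID> ColdBlocks;
  std::vector<BlockID> First = {1, 2, 3};
  std::vector<BlockID> Second = {3, 4};
  bool KeptFirst = claimRegion(First, ColdBlocks);    // all blocks new: kept
  bool KeptSecond = claimRegion(Second, ColdBlocks);  // block 3 overlaps: dropped
  return (KeptFirst && !KeptSecond) ? 0 : 1;
}

As the TODO in the pass notes, keeping only the largest region that contains a given block could outline more code, but the extra bookkeeping is considered too tricky and expensive for the benefit.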
@@ -529,26 +643,17 @@ bool HotColdSplitting::outlineColdRegions(Function &F, ProfileSummaryInfo &PSI, OutliningRegion Region = OutliningWorklist.pop_back_val(); assert(!Region.empty() && "Empty outlining region in worklist"); do { - BlockSequence SubRegion = Region.takeSingleEntrySubRegion(DT); - if (!isProfitableToOutline(SubRegion, TTI)) { - LLVM_DEBUG({ - dbgs() << "Skipping outlining; not profitable to outline\n"; - SubRegion[0]->dump(); - }); - continue; - } - + BlockSequence SubRegion = Region.takeSingleEntrySubRegion(*DT); LLVM_DEBUG({ dbgs() << "Hot/cold splitting attempting to outline these blocks:\n"; for (BasicBlock *BB : SubRegion) BB->dump(); }); - Function *Outlined = - extractColdRegion(SubRegion, DT, BFI, TTI, ORE, OutlinedFunctionID); + Function *Outlined = extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, AC, + OutlinedFunctionID); if (Outlined) { ++OutlinedFunctionID; - OutlinedFunctions.insert(Outlined); Changed = true; } } while (!Region.empty()); @@ -559,20 +664,31 @@ bool HotColdSplitting::outlineColdRegions(Function &F, ProfileSummaryInfo &PSI, bool HotColdSplitting::run(Module &M) { bool Changed = false; - OutlinedFunctions.clear(); - for (auto &F : M) { + bool HasProfileSummary = (M.getProfileSummary(/* IsCS */ false) != nullptr); + for (auto It = M.begin(), End = M.end(); It != End; ++It) { + Function &F = *It; + + // Do not touch declarations. + if (F.isDeclaration()) + continue; + + // Do not modify `optnone` functions. + if (F.hasOptNone()) + continue; + + // Detect inherently cold functions and mark them as such. + if (isFunctionCold(F)) { + Changed |= markFunctionCold(F); + continue; + } + if (!shouldOutlineFrom(F)) { LLVM_DEBUG(llvm::dbgs() << "Skipping " << F.getName() << "\n"); continue; } + LLVM_DEBUG(llvm::dbgs() << "Outlining in " << F.getName() << "\n"); - DominatorTree DT(F); - PostDomTree PDT(F); - PDT.recalculate(F); - BlockFrequencyInfo *BFI = GetBFI(F); - TargetTransformInfo &TTI = GetTTI(F); - OptimizationRemarkEmitter &ORE = (*GetORE)(F); - Changed |= outlineColdRegions(F, *PSI, BFI, TTI, DT, PDT, ORE); + Changed |= outlineColdRegions(F, HasProfileSummary); } return Changed; } @@ -594,17 +710,21 @@ bool HotColdSplittingLegacyPass::runOnModule(Module &M) { ORE.reset(new OptimizationRemarkEmitter(&F)); return *ORE.get(); }; + auto LookupAC = [this](Function &F) -> AssumptionCache * { + if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>()) + return ACT->lookupAssumptionCache(F); + return nullptr; + }; - return HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M); + return HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M); } PreservedAnalyses HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - std::function<AssumptionCache &(Function &)> GetAssumptionCache = - [&FAM](Function &F) -> AssumptionCache & { - return FAM.getResult<AssumptionAnalysis>(F); + auto LookupAC = [&FAM](Function &F) -> AssumptionCache * { + return FAM.getCachedResult<AssumptionAnalysis>(F); }; auto GBFI = [&FAM](Function &F) { @@ -625,7 +745,7 @@ HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) { ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); - if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M)) + if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } diff --git a/lib/Transforms/IPO/IPConstantPropagation.cpp 
b/lib/Transforms/IPO/IPConstantPropagation.cpp index 7d55ebecbf92..7dc4d9ee9e34 100644 --- a/lib/Transforms/IPO/IPConstantPropagation.cpp +++ b/lib/Transforms/IPO/IPConstantPropagation.cpp @@ -1,9 +1,8 @@ //===-- IPConstantPropagation.cpp - Propagate constants through calls -----===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -62,32 +61,55 @@ static bool PropagateConstantsIntoArguments(Function &F) { // Ignore blockaddress uses. if (isa<BlockAddress>(UR)) continue; - // Used by a non-instruction, or not the callee of a function, do not - // transform. - if (!isa<CallInst>(UR) && !isa<InvokeInst>(UR)) + // If no abstract call site was created we did not understand the use, bail. + AbstractCallSite ACS(&U); + if (!ACS) return false; - CallSite CS(cast<Instruction>(UR)); - if (!CS.isCallee(&U)) + // Mismatched argument count is undefined behavior. Simply bail out to avoid + // handling of such situations below (avoiding asserts/crashes). + unsigned NumActualArgs = ACS.getNumArgOperands(); + if (F.isVarArg() ? ArgumentConstants.size() > NumActualArgs + : ArgumentConstants.size() != NumActualArgs) return false; // Check out all of the potentially constant arguments. Note that we don't // inspect varargs here. - CallSite::arg_iterator AI = CS.arg_begin(); Function::arg_iterator Arg = F.arg_begin(); - for (unsigned i = 0, e = ArgumentConstants.size(); i != e; - ++i, ++AI, ++Arg) { + for (unsigned i = 0, e = ArgumentConstants.size(); i != e; ++i, ++Arg) { // If this argument is known non-constant, ignore it. if (ArgumentConstants[i].second) continue; - Constant *C = dyn_cast<Constant>(*AI); + Value *V = ACS.getCallArgOperand(i); + Constant *C = dyn_cast_or_null<Constant>(V); + + // Mismatched argument type is undefined behavior. Simply bail out to avoid + // handling of such situations below (avoiding asserts/crashes). + if (C && Arg->getType() != C->getType()) + return false; + + // We can only propagate thread independent values through callbacks. + // This is different to direct/indirect call sites because for them we + // know the thread executing the caller and callee is the same. For + // callbacks this is not guaranteed, thus a thread dependent value could + // be different for the caller and callee, making it invalid to propagate. + if (C && ACS.isCallbackCall() && C->isThreadDependent()) { + // Argument became non-constant. If all arguments are non-constant now, + // give up on this function. + if (++NumNonconstant == ArgumentConstants.size()) + return false; + + ArgumentConstants[i].second = true; + continue; + } + if (C && ArgumentConstants[i].first == nullptr) { ArgumentConstants[i].first = C; // First constant seen. } else if (C && ArgumentConstants[i].first == C) { // Still the constant value we think it is. - } else if (*AI == &*Arg) { + } else if (V == &*Arg) { // Ignore recursive calls passing argument down. } else { // Argument became non-constant. 
If all arguments are non-constant now, diff --git a/lib/Transforms/IPO/IPO.cpp b/lib/Transforms/IPO/IPO.cpp index 973382e2b097..34db75dd8b03 100644 --- a/lib/Transforms/IPO/IPO.cpp +++ b/lib/Transforms/IPO/IPO.cpp @@ -1,9 +1,8 @@ //===-- IPO.cpp -----------------------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -46,6 +45,7 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeLowerTypeTestsPass(Registry); initializeMergeFunctionsPass(Registry); initializePartialInlinerLegacyPassPass(Registry); + initializeAttributorLegacyPassPass(Registry); initializePostOrderFunctionAttrsLegacyPassPass(Registry); initializeReversePostOrderFunctionAttrsLegacyPassPass(Registry); initializePruneEHPass(Registry); diff --git a/lib/Transforms/IPO/InferFunctionAttrs.cpp b/lib/Transforms/IPO/InferFunctionAttrs.cpp index 470f97b8ba61..7f5511e008e1 100644 --- a/lib/Transforms/IPO/InferFunctionAttrs.cpp +++ b/lib/Transforms/IPO/InferFunctionAttrs.cpp @@ -1,9 +1,8 @@ //===- InferFunctionAttrs.cpp - Infer implicit function attributes --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -26,7 +25,7 @@ static bool inferAllPrototypeAttributes(Module &M, for (Function &F : M.functions()) // We only infer things using the prototype and the name; we don't need // definitions. - if (F.isDeclaration() && !F.hasFnAttribute((Attribute::OptimizeNone))) + if (F.isDeclaration() && !F.hasOptNone()) Changed |= inferLibFuncAttributes(F, TLI); return Changed; diff --git a/lib/Transforms/IPO/InlineSimple.cpp b/lib/Transforms/IPO/InlineSimple.cpp index 82bba1e5c93b..efb71b73cbb7 100644 --- a/lib/Transforms/IPO/InlineSimple.cpp +++ b/lib/Transforms/IPO/InlineSimple.cpp @@ -1,9 +1,8 @@ //===- InlineSimple.cpp - Code to perform simple function inlining --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -69,9 +68,9 @@ public: [&](Function &F) -> AssumptionCache & { return ACT->getAssumptionCache(F); }; - return llvm::getInlineCost(CS, Params, TTI, GetAssumptionCache, - /*GetBFI=*/None, PSI, - RemarksEnabled ? &ORE : nullptr); + return llvm::getInlineCost( + cast<CallBase>(*CS.getInstruction()), Params, TTI, GetAssumptionCache, + /*GetBFI=*/None, PSI, RemarksEnabled ? 
&ORE : nullptr); } bool runOnSCC(CallGraphSCC &SCC) override; diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp index 66a6f80f31e4..945f8affae6e 100644 --- a/lib/Transforms/IPO/Inliner.cpp +++ b/lib/Transforms/IPO/Inliner.cpp @@ -1,9 +1,8 @@ //===- Inliner.cpp - Code common to all inliners --------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -672,7 +671,7 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, LLVM_DEBUG(dbgs() << " -> Deleting dead call: " << *Instr << "\n"); // Update the call graph by deleting the edge from Callee to Caller. setInlineRemark(CS, "trivially dead"); - CG[Caller]->removeCallEdgeFor(CS); + CG[Caller]->removeCallEdgeFor(*cast<CallBase>(CS.getInstruction())); Instr->eraseFromParent(); ++NumCallsDeleted; } else { @@ -974,7 +973,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, LazyCallGraph::Node &N = *CG.lookup(F); if (CG.lookupSCC(N) != C) continue; - if (F.hasFnAttribute(Attribute::OptimizeNone)) { + if (F.hasOptNone()) { setInlineRemark(Calls[i].first, "optnone attribute"); continue; } @@ -1006,8 +1005,12 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, auto GetInlineCost = [&](CallSite CS) { Function &Callee = *CS.getCalledFunction(); auto &CalleeTTI = FAM.getResult<TargetIRAnalysis>(Callee); - return getInlineCost(CS, Params, CalleeTTI, GetAssumptionCache, {GetBFI}, - PSI, &ORE); + bool RemarksEnabled = + Callee.getContext().getDiagHandlerPtr()->isMissedOptRemarkEnabled( + DEBUG_TYPE); + return getInlineCost(cast<CallBase>(*CS.getInstruction()), Params, + CalleeTTI, GetAssumptionCache, {GetBFI}, PSI, + RemarksEnabled ? &ORE : nullptr); }; // Now process as many calls as we have within this caller in the sequnece. diff --git a/lib/Transforms/IPO/Internalize.cpp b/lib/Transforms/IPO/Internalize.cpp index a6542d28dfd8..2e269604e379 100644 --- a/lib/Transforms/IPO/Internalize.cpp +++ b/lib/Transforms/IPO/Internalize.cpp @@ -1,9 +1,8 @@ //===-- Internalize.cpp - Mark functions internal -------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -28,11 +27,11 @@ #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LineIterator.h" +#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/GlobalStatus.h" -#include <fstream> -#include <set> using namespace llvm; #define DEBUG_TYPE "internalize" @@ -73,18 +72,15 @@ private: void LoadFile(StringRef Filename) { // Load the APIFile... 
- std::ifstream In(Filename.data()); - if (!In.good()) { + ErrorOr<std::unique_ptr<MemoryBuffer>> Buf = + MemoryBuffer::getFile(Filename); + if (!Buf) { errs() << "WARNING: Internalize couldn't load file '" << Filename << "'! Continuing as if it's empty.\n"; return; // Just continue as if the file were empty } - while (In) { - std::string Symbol; - In >> Symbol; - if (!Symbol.empty()) - ExternalNames.insert(Symbol); - } + for (line_iterator I(*Buf->get(), true), E; I != E; ++I) + ExternalNames.insert(*I); } }; } // end anonymous namespace @@ -114,7 +110,7 @@ bool InternalizePass::shouldPreserveGV(const GlobalValue &GV) { } bool InternalizePass::maybeInternalize( - GlobalValue &GV, const std::set<const Comdat *> &ExternalComdats) { + GlobalValue &GV, const DenseSet<const Comdat *> &ExternalComdats) { if (Comdat *C = GV.getComdat()) { if (ExternalComdats.count(C)) return false; @@ -141,7 +137,7 @@ bool InternalizePass::maybeInternalize( // If GV is part of a comdat and is externally visible, keep track of its // comdat so that we don't internalize any of its members. void InternalizePass::checkComdatVisibility( - GlobalValue &GV, std::set<const Comdat *> &ExternalComdats) { + GlobalValue &GV, DenseSet<const Comdat *> &ExternalComdats) { Comdat *C = GV.getComdat(); if (!C) return; @@ -158,7 +154,7 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { collectUsedGlobalVariables(M, Used, false); // Collect comdat visiblity information for the module. - std::set<const Comdat *> ExternalComdats; + DenseSet<const Comdat *> ExternalComdats; if (!M.getComdatSymbolTable().empty()) { for (Function &F : M) checkComdatVisibility(F, ExternalComdats); diff --git a/lib/Transforms/IPO/LoopExtractor.cpp b/lib/Transforms/IPO/LoopExtractor.cpp index 733235d45a09..91c7b5f5f135 100644 --- a/lib/Transforms/IPO/LoopExtractor.cpp +++ b/lib/Transforms/IPO/LoopExtractor.cpp @@ -1,9 +1,8 @@ //===- LoopExtractor.cpp - Extract each loop into a new function ----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -15,6 +14,7 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" @@ -51,6 +51,7 @@ namespace { AU.addRequiredID(LoopSimplifyID); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); + AU.addUsedIfAvailable<AssumptionCacheTracker>(); } }; } @@ -139,7 +140,10 @@ bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) { if (ShouldExtractLoop) { if (NumLoops == 0) return Changed; --NumLoops; - CodeExtractor Extractor(DT, *L); + AssumptionCache *AC = nullptr; + if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>()) + AC = ACT->lookupAssumptionCache(*L->getHeader()->getParent()); + CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC); if (Extractor.extractCodeRegion() != nullptr) { Changed = true; // After extraction, the loop is replaced by a function call, so diff --git a/lib/Transforms/IPO/LowerTypeTests.cpp b/lib/Transforms/IPO/LowerTypeTests.cpp index 87c65db09517..f7371284f47e 100644 --- a/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/lib/Transforms/IPO/LowerTypeTests.cpp @@ -1,9 +1,8 @@ //===- LowerTypeTests.cpp - type metadata lowering pass -------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -549,10 +548,10 @@ ByteArrayInfo *LowerTypeTestsModule::createByteArray(BitSetInfo &BSI) { } void LowerTypeTestsModule::allocateByteArrays() { - std::stable_sort(ByteArrayInfos.begin(), ByteArrayInfos.end(), - [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) { - return BAI1.BitSize > BAI2.BitSize; - }); + llvm::stable_sort(ByteArrayInfos, + [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) { + return BAI1.BitSize > BAI2.BitSize; + }); std::vector<uint64_t> ByteArrayOffsets(ByteArrayInfos.size()); @@ -619,7 +618,7 @@ Value *LowerTypeTestsModule::createBitSetTest(IRBuilder<> &B, } Value *ByteAddr = B.CreateGEP(Int8Ty, ByteArray, BitOffset); - Value *Byte = B.CreateLoad(ByteAddr); + Value *Byte = B.CreateLoad(Int8Ty, ByteAddr); Value *ByteAndMask = B.CreateAnd(Byte, ConstantExpr::getPtrToInt(TIL.BitMask, Int8Ty)); @@ -1553,11 +1552,10 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet( // Order the sets of indices by size. The GlobalLayoutBuilder works best // when given small index sets first. - std::stable_sort( - TypeMembers.begin(), TypeMembers.end(), - [](const std::set<uint64_t> &O1, const std::set<uint64_t> &O2) { - return O1.size() < O2.size(); - }); + llvm::stable_sort(TypeMembers, [](const std::set<uint64_t> &O1, + const std::set<uint64_t> &O2) { + return O1.size() < O2.size(); + }); // Create a GlobalLayoutBuilder and provide it with index sets as layout // fragments. 
The GlobalLayoutBuilder tries to lay out members of fragments as @@ -1693,6 +1691,14 @@ void LowerTypeTestsModule::replaceDirectCalls(Value *Old, Value *New) { } bool LowerTypeTestsModule::lower() { + // If only some of the modules were split, we cannot correctly perform + // this transformation. We already checked for the presence of type tests + // with partially split modules during the thin link, and would have emitted + // an error if any were found, so here we can simply return. + if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) || + (ImportSummary && ImportSummary->partiallySplitLTOUnits())) + return false; + Function *TypeTestFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_test)); Function *ICallBranchFunnelFunc = @@ -1702,13 +1708,6 @@ bool LowerTypeTestsModule::lower() { !ExportSummary && !ImportSummary) return false; - // If only some of the modules were split, we cannot correctly handle - // code that contains type tests. - if (TypeTestFunc && !TypeTestFunc->use_empty() && - ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) || - (ImportSummary && ImportSummary->partiallySplitLTOUnits()))) - report_fatal_error("inconsistent LTO Unit splitting with llvm.type.test"); - if (ImportSummary) { if (TypeTestFunc) { for (auto UI = TypeTestFunc->use_begin(), UE = TypeTestFunc->use_end(); diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index 11efe95b10d4..3a08069dcd4a 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -1,9 +1,8 @@ //===- MergeFunctions.cpp - Merge identical functions ---------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -190,8 +189,6 @@ public: void replaceBy(Function *G) const { F = G; } - - void release() { F = nullptr; } }; /// MergeFunctions finds functions which will generate identical machine code, @@ -281,8 +278,8 @@ private: // Replace G with an alias to F (deleting function G) void writeAlias(Function *F, Function *G); - // Replace G with an alias to F if possible, or a thunk to F if - // profitable. Returns false if neither is the case. + // Replace G with an alias to F if possible, or a thunk to F if possible. + // Returns false if neither is the case. bool writeThunkOrAlias(Function *F, Function *G); /// Replace function F with function G in the function tree. @@ -383,6 +380,11 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakTrackingVH> &Worklist) { } #endif +/// Check whether \p F is eligible for function merging.
+static bool isEligibleForMerging(Function &F) { + return !F.isDeclaration() && !F.hasAvailableExternallyLinkage(); +} + bool MergeFunctions::runOnModule(Module &M) { if (skipModule(M)) return false; @@ -394,17 +396,12 @@ bool MergeFunctions::runOnModule(Module &M) { std::vector<std::pair<FunctionComparator::FunctionHash, Function *>> HashedFuncs; for (Function &Func : M) { - if (!Func.isDeclaration() && !Func.hasAvailableExternallyLinkage()) { + if (isEligibleForMerging(Func)) { HashedFuncs.push_back({FunctionComparator::functionHash(Func), &Func}); } } - std::stable_sort( - HashedFuncs.begin(), HashedFuncs.end(), - [](const std::pair<FunctionComparator::FunctionHash, Function *> &a, - const std::pair<FunctionComparator::FunctionHash, Function *> &b) { - return a.first < b.first; - }); + llvm::stable_sort(HashedFuncs, less_first()); auto S = HashedFuncs.begin(); for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) { @@ -654,12 +651,16 @@ void MergeFunctions::filterInstsUnrelatedToPDI( LLVM_DEBUG(dbgs() << " }\n"); } -// Don't merge tiny functions using a thunk, since it can just end up -// making the function larger. -static bool isThunkProfitable(Function * F) { +/// Whether this function may be replaced by a forwarding thunk. +static bool canCreateThunkFor(Function *F) { + if (F->isVarArg()) + return false; + + // Don't merge tiny functions using a thunk, since it can just end up + // making the function larger. if (F->size() == 1) { if (F->front().size() <= 2) { - LLVM_DEBUG(dbgs() << "isThunkProfitable: " << F->getName() + LLVM_DEBUG(dbgs() << "canCreateThunkFor: " << F->getName() << " is too small to bother creating a thunk for\n"); return false; } @@ -695,6 +696,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { } else { NewG = Function::Create(G->getFunctionType(), G->getLinkage(), G->getAddressSpace(), "", G->getParent()); + NewG->setComdat(G->getComdat()); BB = BasicBlock::Create(F->getContext(), "", NewG); } @@ -787,7 +789,7 @@ bool MergeFunctions::writeThunkOrAlias(Function *F, Function *G) { writeAlias(F, G); return true; } - if (isThunkProfitable(F)) { + if (canCreateThunkFor(F)) { writeThunk(F, G); return true; } @@ -802,9 +804,9 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { // Both writeThunkOrAlias() calls below must succeed, either because we can // create aliases for G and NewF, or because a thunk for F is profitable. // F here has the same signature as NewF below, so that's what we check. - if (!isThunkProfitable(F) && (!canCreateAliasFor(F) || !canCreateAliasFor(G))) { + if (!canCreateThunkFor(F) && + (!canCreateAliasFor(F) || !canCreateAliasFor(G))) return; - } // Make them both thunks to the same internal function. Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(), @@ -944,25 +946,7 @@ void MergeFunctions::remove(Function *F) { // For each instruction used by the value, remove() the function that contains // the instruction. This should happen right before a call to RAUW. 
void MergeFunctions::removeUsers(Value *V) { - std::vector<Value *> Worklist; - Worklist.push_back(V); - SmallPtrSet<Value*, 8> Visited; - Visited.insert(V); - while (!Worklist.empty()) { - Value *V = Worklist.back(); - Worklist.pop_back(); - - for (User *U : V->users()) { - if (Instruction *I = dyn_cast<Instruction>(U)) { - remove(I->getFunction()); - } else if (isa<GlobalValue>(U)) { - // do nothing - } else if (Constant *C = dyn_cast<Constant>(U)) { - for (User *UU : C->users()) { - if (!Visited.insert(UU).second) - Worklist.push_back(UU); - } - } - } - } + for (User *U : V->users()) + if (auto *I = dyn_cast<Instruction>(U)) + remove(I->getFunction()); } diff --git a/lib/Transforms/IPO/PartialInlining.cpp b/lib/Transforms/IPO/PartialInlining.cpp index da214a1d3b44..733782e8764d 100644 --- a/lib/Transforms/IPO/PartialInlining.cpp +++ b/lib/Transforms/IPO/PartialInlining.cpp @@ -1,9 +1,8 @@ //===- PartialInlining.cpp - Inline parts of functions --------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -182,11 +181,11 @@ struct FunctionOutliningMultiRegionInfo { // Container for outline regions struct OutlineRegionInfo { - OutlineRegionInfo(SmallVector<BasicBlock *, 8> Region, + OutlineRegionInfo(ArrayRef<BasicBlock *> Region, BasicBlock *EntryBlock, BasicBlock *ExitBlock, BasicBlock *ReturnBlock) - : Region(Region), EntryBlock(EntryBlock), ExitBlock(ExitBlock), - ReturnBlock(ReturnBlock) {} + : Region(Region.begin(), Region.end()), EntryBlock(EntryBlock), + ExitBlock(ExitBlock), ReturnBlock(ReturnBlock) {} SmallVector<BasicBlock *, 8> Region; BasicBlock *EntryBlock; BasicBlock *ExitBlock; @@ -200,10 +199,12 @@ struct PartialInlinerImpl { PartialInlinerImpl( std::function<AssumptionCache &(Function &)> *GetAC, + function_ref<AssumptionCache *(Function &)> LookupAC, std::function<TargetTransformInfo &(Function &)> *GTTI, Optional<function_ref<BlockFrequencyInfo &(Function &)>> GBFI, ProfileSummaryInfo *ProfSI) - : GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {} + : GetAssumptionCache(GetAC), LookupAssumptionCache(LookupAC), + GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {} bool run(Module &M); // Main part of the transformation that calls helper functions to find @@ -223,9 +224,11 @@ struct PartialInlinerImpl { // Two constructors, one for single region outlining, the other for // multi-region outlining. 
FunctionCloner(Function *F, FunctionOutliningInfo *OI, - OptimizationRemarkEmitter &ORE); + OptimizationRemarkEmitter &ORE, + function_ref<AssumptionCache *(Function &)> LookupAC); FunctionCloner(Function *F, FunctionOutliningMultiRegionInfo *OMRI, - OptimizationRemarkEmitter &ORE); + OptimizationRemarkEmitter &ORE, + function_ref<AssumptionCache *(Function &)> LookupAC); ~FunctionCloner(); // Prepare for function outlining: making sure there is only @@ -261,11 +264,13 @@ struct PartialInlinerImpl { std::unique_ptr<FunctionOutliningMultiRegionInfo> ClonedOMRI = nullptr; std::unique_ptr<BlockFrequencyInfo> ClonedFuncBFI = nullptr; OptimizationRemarkEmitter &ORE; + function_ref<AssumptionCache *(Function &)> LookupAC; }; private: int NumPartialInlining = 0; std::function<AssumptionCache &(Function &)> *GetAssumptionCache; + function_ref<AssumptionCache *(Function &)> LookupAssumptionCache; std::function<TargetTransformInfo &(Function &)> *GetTTI; Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI; ProfileSummaryInfo *PSI; @@ -366,12 +371,17 @@ struct PartialInlinerLegacyPass : public ModulePass { return ACT->getAssumptionCache(F); }; + auto LookupAssumptionCache = [ACT](Function &F) -> AssumptionCache * { + return ACT->lookupAssumptionCache(F); + }; + std::function<TargetTransformInfo &(Function &)> GetTTI = [&TTIWP](Function &F) -> TargetTransformInfo & { return TTIWP->getTTI(F); }; - return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, NoneType::None, PSI) + return PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache, + &GetTTI, NoneType::None, PSI) .run(M); } }; @@ -525,7 +535,6 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, // assert(ReturnBlock && "ReturnBlock is NULL somehow!"); FunctionOutliningMultiRegionInfo::OutlineRegionInfo RegInfo( DominateVector, DominateVector.front(), ExitBlock, ReturnBlock); - RegInfo.Region = DominateVector; OutliningInfo->ORI.push_back(RegInfo); #ifndef NDEBUG if (TracePartialInlining) { @@ -763,8 +772,13 @@ bool PartialInlinerImpl::shouldPartialInline( Function *Caller = CS.getCaller(); auto &CalleeTTI = (*GetTTI)(*Callee); - InlineCost IC = getInlineCost(CS, getInlineParams(), CalleeTTI, - *GetAssumptionCache, GetBFI, PSI, &ORE); + bool RemarksEnabled = + Callee->getContext().getDiagHandlerPtr()->isMissedOptRemarkEnabled( + DEBUG_TYPE); + assert(Call && "invalid callsite for partial inline"); + InlineCost IC = getInlineCost(cast<CallBase>(*Call), getInlineParams(), + CalleeTTI, *GetAssumptionCache, GetBFI, PSI, + RemarksEnabled ? 
&ORE : nullptr); if (IC.isAlways()) { ORE.emit([&]() { @@ -798,7 +812,7 @@ bool PartialInlinerImpl::shouldPartialInline( const DataLayout &DL = Caller->getParent()->getDataLayout(); // The savings of eliminating the call: - int NonWeightedSavings = getCallsiteCost(CS, DL); + int NonWeightedSavings = getCallsiteCost(cast<CallBase>(*Call), DL); BlockFrequency NormWeightedSavings(NonWeightedSavings); // Weighted saving is smaller than weighted cost, return false @@ -855,12 +869,12 @@ int PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB) { continue; if (CallInst *CI = dyn_cast<CallInst>(&I)) { - InlineCost += getCallsiteCost(CallSite(CI), DL); + InlineCost += getCallsiteCost(*CI, DL); continue; } if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) { - InlineCost += getCallsiteCost(CallSite(II), DL); + InlineCost += getCallsiteCost(*II, DL); continue; } @@ -949,8 +963,9 @@ void PartialInlinerImpl::computeCallsiteToProfCountMap( } PartialInlinerImpl::FunctionCloner::FunctionCloner( - Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE) - : OrigFunc(F), ORE(ORE) { + Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE, + function_ref<AssumptionCache *(Function &)> LookupAC) + : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) { ClonedOI = llvm::make_unique<FunctionOutliningInfo>(); // Clone the function, so that we can hack away on it. @@ -973,8 +988,9 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner( PartialInlinerImpl::FunctionCloner::FunctionCloner( Function *F, FunctionOutliningMultiRegionInfo *OI, - OptimizationRemarkEmitter &ORE) - : OrigFunc(F), ORE(ORE) { + OptimizationRemarkEmitter &ORE, + function_ref<AssumptionCache *(Function &)> LookupAC) + : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) { ClonedOMRI = llvm::make_unique<FunctionOutliningMultiRegionInfo>(); // Clone the function, so that we can hack away on it. @@ -1112,7 +1128,9 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() { int CurrentOutlinedRegionCost = ComputeRegionCost(RegionInfo.Region); CodeExtractor CE(RegionInfo.Region, &DT, /*AggregateArgs*/ false, - ClonedFuncBFI.get(), &BPI, /* AllowVarargs */ false); + ClonedFuncBFI.get(), &BPI, + LookupAC(*RegionInfo.EntryBlock->getParent()), + /* AllowVarargs */ false); CE.findInputsOutputs(Inputs, Outputs, Sinks); @@ -1194,7 +1212,7 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() { // Extract the body of the if. 
Function *OutlinedFunc = CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, - ClonedFuncBFI.get(), &BPI, + ClonedFuncBFI.get(), &BPI, LookupAC(*ClonedFunc), /* AllowVarargs */ true) .extractCodeRegion(); @@ -1258,7 +1276,7 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) { std::unique_ptr<FunctionOutliningMultiRegionInfo> OMRI = computeOutliningColdRegionsInfo(F, ORE); if (OMRI) { - FunctionCloner Cloner(F, OMRI.get(), ORE); + FunctionCloner Cloner(F, OMRI.get(), ORE, LookupAssumptionCache); #ifndef NDEBUG if (TracePartialInlining) { @@ -1291,7 +1309,7 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) { if (!OI) return {false, nullptr}; - FunctionCloner Cloner(F, OI.get(), ORE); + FunctionCloner Cloner(F, OI.get(), ORE, LookupAssumptionCache); Cloner.NormalizeReturnBlock(); Function *OutlinedFunction = Cloner.doSingleRegionFunctionOutlining(); @@ -1485,6 +1503,10 @@ PreservedAnalyses PartialInlinerPass::run(Module &M, return FAM.getResult<AssumptionAnalysis>(F); }; + auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * { + return FAM.getCachedResult<AssumptionAnalysis>(F); + }; + std::function<BlockFrequencyInfo &(Function &)> GetBFI = [&FAM](Function &F) -> BlockFrequencyInfo & { return FAM.getResult<BlockFrequencyAnalysis>(F); @@ -1497,7 +1519,8 @@ PreservedAnalyses PartialInlinerPass::run(Module &M, ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); - if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI) + if (PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache, &GetTTI, + {GetBFI}, PSI) .run(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp index 9764944dc332..3ea77f08fd3c 100644 --- a/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -1,9 +1,8 @@ //===- PassManagerBuilder.cpp - Build Standard Pass -----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -31,6 +30,7 @@ #include "llvm/Support/ManagedStatic.h" #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/Attributor.h" #include "llvm/Transforms/IPO/ForceFunctionAttrs.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/IPO/InferFunctionAttrs.h" @@ -39,9 +39,13 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVN.h" #include "llvm/Transforms/Scalar/InstSimplifyPass.h" +#include "llvm/Transforms/Scalar/LICM.h" +#include "llvm/Transforms/Scalar/LoopUnrollPass.h" #include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Vectorize.h" +#include "llvm/Transforms/Vectorize/LoopVectorize.h" +#include "llvm/Transforms/Vectorize/SLPVectorizer.h" using namespace llvm; @@ -50,14 +54,6 @@ static cl::opt<bool> cl::ZeroOrMore, cl::desc("Run Partial inlinining pass")); static cl::opt<bool> - RunLoopVectorization("vectorize-loops", cl::Hidden, - cl::desc("Run the Loop vectorization passes")); - -static cl::opt<bool> -RunSLPVectorization("vectorize-slp", cl::Hidden, - cl::desc("Run the SLP vectorization passes")); - -static cl::opt<bool> UseGVNAfterVectorization("use-gvn-after-vectorization", cl::init(false), cl::Hidden, cl::desc("Run GVN instead of Early CSE after vectorization passes")); @@ -73,12 +69,6 @@ RunLoopRerolling("reroll-loops", cl::Hidden, static cl::opt<bool> RunNewGVN("enable-newgvn", cl::init(false), cl::Hidden, cl::desc("Run the NewGVN pass")); -static cl::opt<bool> -RunSLPAfterLoopVectorization("run-slp-after-loop-vectorization", - cl::init(true), cl::Hidden, - cl::desc("Run the SLP vectorizer (and BB vectorizer) after the Loop " - "vectorizer instead of before")); - // Experimental option to use CFL-AA enum class CFLAAType { None, Steensgaard, Andersen, Both }; static cl::opt<CFLAAType> @@ -104,23 +94,13 @@ static cl::opt<bool> EnablePrepareForThinLTO("prepare-for-thinlto", cl::init(false), cl::Hidden, cl::desc("Enable preparation for ThinLTO.")); +static cl::opt<bool> + EnablePerformThinLTO("perform-thinlto", cl::init(false), cl::Hidden, + cl::desc("Enable performing ThinLTO.")); + cl::opt<bool> EnableHotColdSplit("hot-cold-split", cl::init(false), cl::Hidden, cl::desc("Enable hot-cold splitting pass")); - -static cl::opt<bool> RunPGOInstrGen( - "profile-generate", cl::init(false), cl::Hidden, - cl::desc("Enable PGO instrumentation.")); - -static cl::opt<std::string> - PGOOutputFile("profile-generate-file", cl::init(""), cl::Hidden, - cl::desc("Specify the path of profile data file.")); - -static cl::opt<std::string> RunPGOInstrUse( - "profile-use", cl::init(""), cl::Hidden, cl::value_desc("filename"), - cl::desc("Enable use phase of PGO instrumentation and specify the path " - "of profile data file")); - static cl::opt<bool> UseLoopVersioningLICM( "enable-loop-versioning-licm", cl::init(false), cl::Hidden, cl::desc("Enable the experimental Loop Versioning LICM pass")); @@ -134,10 +114,6 @@ static cl::opt<int> PreInlineThreshold( cl::desc("Control the amount of inlining in pre-instrumentation inliner " "(default = 75)")); -static cl::opt<bool> EnableEarlyCSEMemSSA( - "enable-earlycse-memssa", cl::init(true), cl::Hidden, - cl::desc("Enable the EarlyCSE w/ MemorySSA pass (default = on)")); - static cl::opt<bool> EnableGVNHoist( "enable-gvn-hoist", 
cl::init(false), cl::Hidden, cl::desc("Enable the GVN hoisting pass (default = off)")); @@ -156,10 +132,21 @@ static cl::opt<bool> EnableGVNSink( "enable-gvn-sink", cl::init(false), cl::Hidden, cl::desc("Enable the GVN sinking pass (default = off)")); +// This option is used in simplifying testing SampleFDO optimizations for +// profile loading. static cl::opt<bool> EnableCHR("enable-chr", cl::init(true), cl::Hidden, cl::desc("Enable control height reduction optimization (CHR)")); +cl::opt<bool> FlattenedProfileUsed( + "flattened-profile-used", cl::init(false), cl::Hidden, + cl::desc("Indicate the sample profile being used is flattened, i.e., " + "no inline hierachy exists in the profile. ")); + +cl::opt<bool> EnableOrderFileInstrumentation( + "enable-order-file-instrumentation", cl::init(false), cl::Hidden, + cl::desc("Enable order file instrumentation (default = off)")); + PassManagerBuilder::PassManagerBuilder() { OptLevel = 2; SizeLevel = 0; @@ -167,19 +154,26 @@ PassManagerBuilder::PassManagerBuilder() { Inliner = nullptr; DisableUnrollLoops = false; SLPVectorize = RunSLPVectorization; - LoopVectorize = RunLoopVectorization; + LoopVectorize = EnableLoopVectorization; + LoopsInterleaved = EnableLoopInterleaving; RerollLoops = RunLoopRerolling; NewGVN = RunNewGVN; + LicmMssaOptCap = SetLicmMssaOptCap; + LicmMssaNoAccForPromotionCap = SetLicmMssaNoAccForPromotionCap; DisableGVNLoadPRE = false; + ForgetAllSCEVInLoopUnroll = ForgetSCEVInLoopUnroll; VerifyInput = false; VerifyOutput = false; MergeFunctions = false; PrepareForLTO = false; - EnablePGOInstrGen = RunPGOInstrGen; - PGOInstrGen = PGOOutputFile; - PGOInstrUse = RunPGOInstrUse; + EnablePGOInstrGen = false; + EnablePGOCSInstrGen = false; + EnablePGOCSInstrUse = false; + PGOInstrGen = ""; + PGOInstrUse = ""; + PGOSampleUse = ""; PrepareForThinLTO = EnablePrepareForThinLTO; - PerformThinLTO = false; + PerformThinLTO = EnablePerformThinLTO; DivergentTarget = false; } @@ -272,13 +266,19 @@ void PassManagerBuilder::populateFunctionPassManager( } // Do PGO instrumentation generation or use pass as the option specified. -void PassManagerBuilder::addPGOInstrPasses(legacy::PassManagerBase &MPM) { - if (!EnablePGOInstrGen && PGOInstrUse.empty() && PGOSampleUse.empty()) +void PassManagerBuilder::addPGOInstrPasses(legacy::PassManagerBase &MPM, + bool IsCS = false) { + if (IsCS) { + if (!EnablePGOCSInstrGen && !EnablePGOCSInstrUse) + return; + } else if (!EnablePGOInstrGen && PGOInstrUse.empty() && PGOSampleUse.empty()) return; + // Perform the preinline and cleanup passes for O1 and above. // And avoid doing them if optimizing for size. + // We will not do this inline for context sensitive PGO (when IsCS is true). if (OptLevel > 0 && SizeLevel == 0 && !DisablePreInliner && - PGOSampleUse.empty()) { + PGOSampleUse.empty() && !IsCS) { // Create preinline pass. We construct an InlineParams object and specify // the threshold here to avoid the command line options of the regular // inliner to influence pre-inlining. The only fields of InlineParams we @@ -296,22 +296,23 @@ void PassManagerBuilder::addPGOInstrPasses(legacy::PassManagerBase &MPM) { MPM.add(createInstructionCombiningPass()); // Combine silly seq's addExtensionsToPM(EP_Peephole, MPM); } - if (EnablePGOInstrGen) { - MPM.add(createPGOInstrumentationGenLegacyPass()); + if ((EnablePGOInstrGen && !IsCS) || (EnablePGOCSInstrGen && IsCS)) { + MPM.add(createPGOInstrumentationGenLegacyPass(IsCS)); // Add the profile lowering pass. 
InstrProfOptions Options; if (!PGOInstrGen.empty()) Options.InstrProfileOutput = PGOInstrGen; Options.DoCounterPromotion = true; + Options.UseBFIInPromotion = IsCS; MPM.add(createLoopRotatePass()); - MPM.add(createInstrProfilingLegacyPass(Options)); + MPM.add(createInstrProfilingLegacyPass(Options, IsCS)); } if (!PGOInstrUse.empty()) - MPM.add(createPGOInstrumentationUseLegacyPass(PGOInstrUse)); + MPM.add(createPGOInstrumentationUseLegacyPass(PGOInstrUse, IsCS)); // Indirect call promotion that promotes intra-module targets only. // For ThinLTO this is done earlier due to interactions with globalopt // for imported functions. We don't run this at -O0. - if (OptLevel > 0) + if (OptLevel > 0 && !IsCS) MPM.add( createPGOIndirectCallPromotionLegacyPass(false, !PGOSampleUse.empty())); } @@ -320,7 +321,7 @@ void PassManagerBuilder::addFunctionSimplificationPasses( // Start of function pass. // Break up aggregate allocas, using SSAUpdater. MPM.add(createSROAPass()); - MPM.add(createEarlyCSEPass(EnableEarlyCSEMemSSA)); // Catch trivial redundancies + MPM.add(createEarlyCSEPass(true /* Enable mem-ssa. */)); // Catch trivial redundancies if (EnableGVNHoist) MPM.add(createGVNHoistPass()); if (EnableGVNSink) { @@ -359,7 +360,7 @@ void PassManagerBuilder::addFunctionSimplificationPasses( } // Rotate Loop - disable header duplication at -Oz MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1)); - MPM.add(createLICMPass()); // Hoist loop invariants + MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap)); if (EnableSimpleLoopUnswitch) MPM.add(createSimpleLoopUnswitchLegacyPass()); else @@ -378,8 +379,9 @@ void PassManagerBuilder::addFunctionSimplificationPasses( if (EnableLoopInterchange) MPM.add(createLoopInterchangePass()); // Interchange loops - MPM.add(createSimpleLoopUnrollPass(OptLevel, - DisableUnrollLoops)); // Unroll small loops + // Unroll small loops + MPM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops, + ForgetAllSCEVInLoopUnroll)); addExtensionsToPM(EP_LoopOptimizerEnd, MPM); // This ends the loop pass pipelines. @@ -403,14 +405,12 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createJumpThreadingPass()); // Thread jumps MPM.add(createCorrelatedValuePropagationPass()); MPM.add(createDeadStoreEliminationPass()); // Delete dead stores - MPM.add(createLICMPass()); + MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap)); addExtensionsToPM(EP_ScalarOptimizerLate, MPM); if (RerollLoops) MPM.add(createLoopRerollPass()); - if (!RunSLPAfterLoopVectorization && SLPVectorize) - MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. MPM.add(createAggressiveDCEPass()); // Delete dead instructions MPM.add(createCFGSimplificationPass()); // Merge & remove BBs @@ -419,15 +419,23 @@ void PassManagerBuilder::addFunctionSimplificationPasses( addExtensionsToPM(EP_Peephole, MPM); if (EnableCHR && OptLevel >= 3 && - (!PGOInstrUse.empty() || !PGOSampleUse.empty())) + (!PGOInstrUse.empty() || !PGOSampleUse.empty() || EnablePGOCSInstrGen)) MPM.add(createControlHeightReductionLegacyPass()); } void PassManagerBuilder::populateModulePassManager( legacy::PassManagerBase &MPM) { + // Whether this is a default or *LTO pre-link pipeline. The FullLTO post-link + // is handled separately, so just check this is not the ThinLTO post-link. 
+ bool DefaultOrPreLinkPipeline = !PerformThinLTO; + if (!PGOSampleUse.empty()) { MPM.add(createPruneEHPass()); - MPM.add(createSampleProfileLoaderPass(PGOSampleUse)); + // In ThinLTO mode, when flattened profile is used, all the available + // profile information will be annotated in PreLink phase so there is + // no need to load the profile again in PostLink. + if (!(FlattenedProfileUsed && PerformThinLTO)) + MPM.add(createSampleProfileLoaderPass(PGOSampleUse)); } // Allow forcing function attributes as a debugging and tuning aid. @@ -508,6 +516,10 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createIPSCCPPass()); // IP SCCP MPM.add(createCalledValuePropagationPass()); + + // Infer attributes on declarations, call sites, arguments, etc. + MPM.add(createAttributorLegacyPass()); + MPM.add(createGlobalOptimizerPass()); // Optimize out global vars // Promote any localized global vars. MPM.add(createPromoteMemoryToRegisterPass()); @@ -523,9 +535,14 @@ void PassManagerBuilder::populateModulePassManager( // profile annotation in backend more difficult. // PGO instrumentation is added during the compile phase for ThinLTO, do // not run it a second time - if (!PerformThinLTO && !PrepareForThinLTOUsingPGOSampleProfile) + if (DefaultOrPreLinkPipeline && !PrepareForThinLTOUsingPGOSampleProfile) addPGOInstrPasses(MPM); + // Create profile COMDAT variables. Lld linker wants to see all variables + // before the LTO/ThinLTO link since it needs to resolve symbols/comdats. + if (!PerformThinLTO && EnablePGOCSInstrGen) + MPM.add(createPGOInstrumentationGenCreateVarLegacyPass(PGOInstrGen)); + // We add a module alias analysis pass here. In part due to bugs in the // analysis infrastructure this "works" in that the analysis stays alive // for the entire SCC pass run below. @@ -567,6 +584,17 @@ void PassManagerBuilder::populateModulePassManager( // and saves running remaining passes on the eliminated functions. MPM.add(createEliminateAvailableExternallyPass()); + // CSFDO instrumentation and use pass. Don't invoke this for Prepare pass + // for LTO and ThinLTO -- The actual pass will be called after all inlines + // are performed. + // Need to do this after COMDAT variables have been eliminated, + // (i.e. after EliminateAvailableExternallyPass). + if (!(PrepareForLTO || PrepareForThinLTO)) + addPGOInstrPasses(MPM, /* IsCS */ true); + + if (EnableOrderFileInstrumentation) + MPM.add(createInstrOrderFilePass()); + MPM.add(createReversePostOrderFunctionAttrsPass()); // The inliner performs some kind of dead code elimination as it goes, @@ -605,7 +633,7 @@ void PassManagerBuilder::populateModulePassManager( // later might get benefit of no-alias assumption in clone loop. if (UseLoopVersioningLICM) { MPM.add(createLoopVersioningLICMPass()); // Do LoopVersioningLICM - MPM.add(createLICMPass()); // Hoist loop invariants + MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap)); } // We add a fresh GlobalsModRef run at this point. This is particularly @@ -640,7 +668,7 @@ void PassManagerBuilder::populateModulePassManager( // llvm.loop.distribute=true or when -enable-loop-distribute is specified. MPM.add(createLoopDistributePass()); - MPM.add(createLoopVectorizePass(DisableUnrollLoops, !LoopVectorize)); + MPM.add(createLoopVectorizePass(!LoopsInterleaved, !LoopVectorize)); // Eliminate loads by forwarding stores from the previous iteration to loads // of the current iteration. 
@@ -662,7 +690,7 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createEarlyCSEPass()); MPM.add(createCorrelatedValuePropagationPass()); addInstructionCombiningPass(MPM); - MPM.add(createLICMPass()); + MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap)); MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget)); MPM.add(createCFGSimplificationPass()); addInstructionCombiningPass(MPM); @@ -675,7 +703,7 @@ void PassManagerBuilder::populateModulePassManager( // before SLP vectorization. MPM.add(createCFGSimplificationPass(1, true, true, false, true)); - if (RunSLPAfterLoopVectorization && SLPVectorize) { + if (SLPVectorize) { MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. if (OptLevel > 1 && ExtraVectorizerPasses) { MPM.add(createEarlyCSEPass()); @@ -692,8 +720,9 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createLoopUnrollAndJamPass(OptLevel)); } - MPM.add(createLoopUnrollPass(OptLevel, - DisableUnrollLoops)); // Unroll small loops + // Unroll small loops + MPM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops, + ForgetAllSCEVInLoopUnroll)); if (!DisableUnrollLoops) { // LoopUnroll may generate some redundency to cleanup. @@ -703,7 +732,7 @@ void PassManagerBuilder::populateModulePassManager( // unrolled loop is a inner loop, then the prologue will be inside the // outer loop. LICM pass can help to promote the runtime check out if the // checked value is loop invariant. - MPM.add(createLICMPass()); + MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap)); } MPM.add(createWarnMissedTransformationsPass()); @@ -722,6 +751,11 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createConstantMergePass()); // Merge dup global constants } + // See comment in the new PM for justification of scheduling splitting at + // this stage (\ref buildModuleSimplificationPipeline). + if (EnableHotColdSplit && !(PrepareForLTO || PrepareForThinLTO)) + MPM.add(createHotColdSplittingPass()); + if (MergeFunctions) MPM.add(createMergeFunctionsPass()); @@ -738,9 +772,6 @@ void PassManagerBuilder::populateModulePassManager( // flattening of blocks. MPM.add(createDivRemPairsPass()); - if (EnableHotColdSplit) - MPM.add(createHotColdSplittingPass()); - // LoopSink (and other loop passes since the last simplifyCFG) might have // resulted in single-entry-single-exit or empty blocks. Clean up the CFG. MPM.add(createCFGSimplificationPass()); @@ -793,6 +824,9 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // Attach metadata to indirect call sites indicating the set of functions // they may target at run-time. This should follow IPSCCP. PM.add(createCalledValuePropagationPass()); + + // Infer attributes on declarations, call sites, arguments, etc. + PM.add(createAttributorLegacyPass()); } // Infer attributes about definitions. The readnone attribute in particular is @@ -842,6 +876,9 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createPruneEHPass()); // Remove dead EH info. + // CSFDO instrumentation and use pass. + addPGOInstrPasses(PM, /* IsCS */ true); + // Optimize globals again if we ran the inliner. if (RunInliner) PM.add(createGlobalOptimizerPass()); @@ -859,11 +896,16 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // Break up allocas PM.add(createSROAPass()); - // Run a few AA driven optimizations here and now, to cleanup the code. 
+ // LTO provides additional opportunities for tailcall elimination due to + // link-time inlining, and visibility of nocapture attribute. + PM.add(createTailCallEliminationPass()); + + // Infer attributes on declarations, call sites, arguments, etc. PM.add(createPostOrderFunctionAttrsLegacyPass()); // Add nocapture. + // Run a few AA driven optimizations here and now, to cleanup the code. PM.add(createGlobalsAAWrapperPass()); // IP alias analysis. - PM.add(createLICMPass()); // Hoist loop invariants. + PM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap)); PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds. PM.add(NewGVN ? createNewGVNPass() : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies. @@ -878,11 +920,13 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { if (EnableLoopInterchange) PM.add(createLoopInterchangePass()); - PM.add(createSimpleLoopUnrollPass(OptLevel, - DisableUnrollLoops)); // Unroll small loops + // Unroll small loops + PM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops, + ForgetAllSCEVInLoopUnroll)); PM.add(createLoopVectorizePass(true, !LoopVectorize)); // The vectorizer may have significantly shortened a loop body; unroll again. - PM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops)); + PM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops, + ForgetAllSCEVInLoopUnroll)); PM.add(createWarnMissedTransformationsPass()); @@ -896,9 +940,8 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createBitTrackingDCEPass()); // More scalar chains could be vectorized due to more alias information - if (RunSLPAfterLoopVectorization) - if (SLPVectorize) - PM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. + if (SLPVectorize) + PM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. // After vectorization, assume intrinsics may tell us more about pointer // alignments. @@ -913,6 +956,11 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { void PassManagerBuilder::addLateLTOOptimizationPasses( legacy::PassManagerBase &PM) { + // See comment in the new PM for justification of scheduling splitting at + // this stage (\ref buildLTODefaultPipeline). + if (EnableHotColdSplit) + PM.add(createHotColdSplittingPass()); + // Delete basic blocks, which optimization passes may have killed. PM.add(createCFGSimplificationPass()); @@ -968,6 +1016,8 @@ void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) { if (VerifyInput) PM.add(createVerifierPass()); + addExtensionsToPM(EP_FullLinkTimeOptimizationEarly, PM); + if (OptLevel != 0) addLTOOptimizationPasses(PM); else { @@ -989,6 +1039,8 @@ void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) { if (OptLevel != 0) addLateLTOOptimizationPasses(PM); + addExtensionsToPM(EP_FullLinkTimeOptimizationLast, PM); + if (VerifyOutput) PM.add(createVerifierPass()); } diff --git a/lib/Transforms/IPO/PruneEH.cpp b/lib/Transforms/IPO/PruneEH.cpp index ae586c017471..cb3915dfb678 100644 --- a/lib/Transforms/IPO/PruneEH.cpp +++ b/lib/Transforms/IPO/PruneEH.cpp @@ -1,9 +1,8 @@ //===- PruneEH.cpp - Pass which deletes unused exception handlers ---------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -204,7 +203,8 @@ static bool SimplifyFunction(Function *F, CallGraph &CG) { for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) if (CallInst *CI = dyn_cast<CallInst>(I++)) - if (CI->doesNotReturn() && !isa<UnreachableInst>(I)) { + if (CI->doesNotReturn() && !CI->isMustTailCall() && + !isa<UnreachableInst>(I)) { // This call calls a function that cannot return. Insert an // unreachable instruction after it and simplify the code. Do this // by splitting the BB, adding the unreachable, then deleting the @@ -242,12 +242,12 @@ static void DeleteBasicBlock(BasicBlock *BB, CallGraph &CG) { break; } - if (auto CS = CallSite (&*I)) { - const Function *Callee = CS.getCalledFunction(); + if (auto *Call = dyn_cast<CallBase>(&*I)) { + const Function *Callee = Call->getCalledFunction(); if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID())) - CGN->removeCallEdgeFor(CS); + CGN->removeCallEdgeFor(*Call); else if (!Callee->isIntrinsic()) - CGN->removeCallEdgeFor(CS); + CGN->removeCallEdgeFor(*Call); } if (!I->use_empty()) diff --git a/lib/Transforms/IPO/SCCP.cpp b/lib/Transforms/IPO/SCCP.cpp index d2c34abfc132..7be3608bd2ec 100644 --- a/lib/Transforms/IPO/SCCP.cpp +++ b/lib/Transforms/IPO/SCCP.cpp @@ -79,6 +79,7 @@ char IPSCCPLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp", "Interprocedural Sparse Conditional Constant Propagation", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp", diff --git a/lib/Transforms/IPO/SampleProfile.cpp b/lib/Transforms/IPO/SampleProfile.cpp index 9f123c2b875e..877d20e72ffc 100644 --- a/lib/Transforms/IPO/SampleProfile.cpp +++ b/lib/Transforms/IPO/SampleProfile.cpp @@ -1,9 +1,8 @@ //===- SampleProfile.cpp - Incorporate sample profiles into the IR --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -319,6 +318,14 @@ protected: /// Optimization Remark Emitter used to emit diagnostic remarks. OptimizationRemarkEmitter *ORE = nullptr; + + // Information recorded when we declined to inline a call site + // because we have determined it is too cold is accumulated for + // each callee function. Initially this is just the entry count. + struct NotInlinedProfileInfo { + uint64_t entryCount; + }; + DenseMap<Function *, NotInlinedProfileInfo> notInlinedCallInfo; }; class SampleProfileLoaderLegacyPass : public ModulePass { @@ -745,8 +752,9 @@ bool SampleProfileLoader::inlineCallInstruction(Instruction *I) { // when cost exceeds threshold without checking all IRs in the callee. // The acutal cost does not matter because we only checks isNever() to // see if it is legal to inline the callsite. 
- InlineCost Cost = getInlineCost(CS, Params, GetTTI(*CalledFunction), GetAC, - None, nullptr, nullptr); + InlineCost Cost = + getInlineCost(cast<CallBase>(*I), Params, GetTTI(*CalledFunction), GetAC, + None, nullptr, nullptr); if (Cost.isNever()) { ORE->emit(OptimizationRemark(DEBUG_TYPE, "Not inline", DLoc, BB) << "incompatible inlining"); @@ -779,6 +787,8 @@ bool SampleProfileLoader::inlineCallInstruction(Instruction *I) { bool SampleProfileLoader::inlineHotFunctions( Function &F, DenseSet<GlobalValue::GUID> &InlinedGUIDs) { DenseSet<Instruction *> PromotedInsns; + + DenseMap<Instruction *, const FunctionSamples *> localNotInlinedCallSites; bool Changed = false; while (true) { bool LocalChanged = false; @@ -791,6 +801,8 @@ bool SampleProfileLoader::inlineHotFunctions( if ((isa<CallInst>(I) || isa<InvokeInst>(I)) && !isa<IntrinsicInst>(I) && (FS = findCalleeFunctionSamples(I))) { Candidates.push_back(&I); + if (FS->getEntrySamples() > 0) + localNotInlinedCallSites.try_emplace(&I, FS); if (callsiteIsHot(FS, PSI)) Hot = true; } @@ -823,6 +835,9 @@ bool SampleProfileLoader::inlineHotFunctions( if (CalleeFunctionName == F.getName()) continue; + if (!callsiteIsHot(FS, PSI)) + continue; + const char *Reason = "Callee function not available"; auto R = SymbolMap.find(CalleeFunctionName); if (R != SymbolMap.end() && R->getValue() && @@ -836,8 +851,10 @@ bool SampleProfileLoader::inlineHotFunctions( PromotedInsns.insert(I); // If profile mismatches, we should not attempt to inline DI. if ((isa<CallInst>(DI) || isa<InvokeInst>(DI)) && - inlineCallInstruction(DI)) + inlineCallInstruction(DI)) { + localNotInlinedCallSites.erase(I); LocalChanged = true; + } } else { LLVM_DEBUG(dbgs() << "\nFailed to promote indirect call to " @@ -846,8 +863,10 @@ bool SampleProfileLoader::inlineHotFunctions( } } else if (CalledFunction && CalledFunction->getSubprogram() && !CalledFunction->isDeclaration()) { - if (inlineCallInstruction(I)) + if (inlineCallInstruction(I)) { + localNotInlinedCallSites.erase(I); LocalChanged = true; + } } else if (IsThinLTOPreLink) { findCalleeFunctionSamples(*I)->findInlinedFunctions( InlinedGUIDs, F.getParent(), PSI->getOrCompHotCountThreshold()); @@ -859,6 +878,18 @@ bool SampleProfileLoader::inlineHotFunctions( break; } } + + // Accumulate not inlined callsite information into notInlinedSamples + for (const auto &Pair : localNotInlinedCallSites) { + Instruction *I = Pair.getFirst(); + Function *Callee = CallSite(I).getCalledFunction(); + if (!Callee || Callee->isDeclaration()) + continue; + const FunctionSamples *FS = Pair.getSecond(); + auto pair = + notInlinedCallInfo.try_emplace(Callee, NotInlinedProfileInfo{0}); + pair.first->second.entryCount += FS->getEntrySamples(); + } return Changed; } @@ -1299,10 +1330,10 @@ void SampleProfileLoader::propagateWeights(Function &F) { annotateValueSite(*I.getParent()->getParent()->getParent(), I, SortedCallTargets, Sum, IPVK_IndirectCallTarget, SortedCallTargets.size()); - } else if (!dyn_cast<IntrinsicInst>(&I)) { - SmallVector<uint32_t, 1> Weights; - Weights.push_back(BlockWeights[BB]); - I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); + } else if (!isa<IntrinsicInst>(&I)) { + I.setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights( + {static_cast<uint32_t>(BlockWeights[BB])})); } } } @@ -1568,8 +1599,9 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, return false; PSI = _PSI; - if (M.getProfileSummary() == nullptr) - M.setProfileSummary(Reader->getSummary().getMD(M.getContext())); + 
if (M.getProfileSummary(/* IsCS */ false) == nullptr) + M.setProfileSummary(Reader->getSummary().getMD(M.getContext()), + ProfileSummary::PSK_Sample); // Compute the total number of samples collected in this profile. for (const auto &I : Reader->getProfiles()) @@ -1601,6 +1633,12 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, clearFunctionData(); retval |= runOnFunction(F, AM); } + + // Account for cold calls not inlined.... + for (const std::pair<Function *, NotInlinedProfileInfo> &pair : + notInlinedCallInfo) + updateProfileCallee(pair.first, pair.second.entryCount); + return retval; } diff --git a/lib/Transforms/IPO/StripDeadPrototypes.cpp b/lib/Transforms/IPO/StripDeadPrototypes.cpp index 3c3c5dd19d1f..106db3c8bd9d 100644 --- a/lib/Transforms/IPO/StripDeadPrototypes.cpp +++ b/lib/Transforms/IPO/StripDeadPrototypes.cpp @@ -1,9 +1,8 @@ //===-- StripDeadPrototypes.cpp - Remove unused function declarations ----===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/IPO/StripSymbols.cpp b/lib/Transforms/IPO/StripSymbols.cpp index c9afb060a91a..67a473612fc1 100644 --- a/lib/Transforms/IPO/StripSymbols.cpp +++ b/lib/Transforms/IPO/StripSymbols.cpp @@ -1,9 +1,8 @@ //===- StripSymbols.cpp - Strip symbols and debug info from a module ------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/lib/Transforms/IPO/SyntheticCountsPropagation.cpp index ba4efb3ff60d..45fd432fd721 100644 --- a/lib/Transforms/IPO/SyntheticCountsPropagation.cpp +++ b/lib/Transforms/IPO/SyntheticCountsPropagation.cpp @@ -1,9 +1,8 @@ //=- SyntheticCountsPropagation.cpp - Propagate function counts --*- C++ -*-=// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index 510ecb516dc2..24c476376c14 100644 --- a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -1,9 +1,8 @@ //===- ThinLTOBitcodeWriter.cpp - Bitcode writing pass for ThinLTO --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. 
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -418,34 +417,53 @@ void splitAndWriteThinLTOBitcode( } } -// Returns whether this module needs to be split because splitting is -// enabled and it uses type metadata. -bool requiresSplit(Module &M) { - // First check if the LTO Unit splitting has been enabled. +// Check if the LTO Unit splitting has been enabled. +bool enableSplitLTOUnit(Module &M) { bool EnableSplitLTOUnit = false; if (auto *MD = mdconst::extract_or_null<ConstantInt>( M.getModuleFlag("EnableSplitLTOUnit"))) EnableSplitLTOUnit = MD->getZExtValue(); - if (!EnableSplitLTOUnit) - return false; + return EnableSplitLTOUnit; +} - // Module only needs to be split if it contains type metadata. +// Returns whether this module needs to be split because it uses type metadata. +bool hasTypeMetadata(Module &M) { for (auto &GO : M.global_objects()) { if (GO.hasMetadata(LLVMContext::MD_type)) return true; } - return false; } void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, function_ref<AAResults &(Function &)> AARGetter, Module &M, const ModuleSummaryIndex *Index) { - // Split module if splitting is enabled and it contains any type metadata. - if (requiresSplit(M)) - return splitAndWriteThinLTOBitcode(OS, ThinLinkOS, AARGetter, M); + std::unique_ptr<ModuleSummaryIndex> NewIndex = nullptr; + // See if this module has any type metadata. If so, we try to split it + // or at least promote type ids to enable WPD. + if (hasTypeMetadata(M)) { + if (enableSplitLTOUnit(M)) + return splitAndWriteThinLTOBitcode(OS, ThinLinkOS, AARGetter, M); + // Promote type ids as needed for index-based WPD. + std::string ModuleId = getUniqueModuleId(&M); + if (!ModuleId.empty()) { + promoteTypeIds(M, ModuleId); + // Need to rebuild the index so that it contains type metadata + // for the newly promoted type ids. + // FIXME: Probably should not bother building the index at all + // in the caller of writeThinLTOBitcode (which does so via the + // ModuleSummaryIndexAnalysis pass), since we have to rebuild it + // anyway whenever there is type metadata (here or in + // splitAndWriteThinLTOBitcode). Just always build it once via the + // buildModuleSummaryIndex when Module(s) are ready. + ProfileSummaryInfo PSI(M); + NewIndex = llvm::make_unique<ModuleSummaryIndex>( + buildModuleSummaryIndex(M, nullptr, &PSI)); + Index = NewIndex.get(); + } + } - // Otherwise we can just write it out as a regular module. + // Write it out as an unsplit ThinLTO module. // Save the module hash produced for the full bitcode, which will // be used in the backends, and use that in the minimized bitcode diff --git a/lib/Transforms/IPO/WholeProgramDevirt.cpp b/lib/Transforms/IPO/WholeProgramDevirt.cpp index 48bd0cda759d..6b6dd6194e17 100644 --- a/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -1,9 +1,8 @@ //===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -882,7 +881,7 @@ void DevirtModule::tryICallBranchFunnel( } BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr); - Constant *Intr = + Function *Intr = Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {}); auto *CI = CallInst::Create(Intr, JTArgs, "", BB); @@ -921,9 +920,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, NewArgs.push_back(Int8PtrTy); for (Type *T : CS.getFunctionType()->params()) NewArgs.push_back(T); - PointerType *NewFT = PointerType::getUnqual( + FunctionType *NewFT = FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs, - CS.getFunctionType()->isVarArg())); + CS.getFunctionType()->isVarArg()); + PointerType *NewFTPtr = PointerType::getUnqual(NewFT); IRBuilder<> IRB(CS.getInstruction()); std::vector<Value *> Args; @@ -933,10 +933,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, CallSite NewCS; if (CS.isCall()) - NewCS = IRB.CreateCall(IRB.CreateBitCast(JT, NewFT), Args); + NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args); else NewCS = IRB.CreateInvoke( - IRB.CreateBitCast(JT, NewFT), + NewFT, IRB.CreateBitCast(JT, NewFTPtr), cast<InvokeInst>(CS.getInstruction())->getNormalDest(), cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args); NewCS.setCallingConv(CS.getCallingConv()); @@ -1183,7 +1183,7 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, Value *Addr = B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte); if (RetType->getBitWidth() == 1) { - Value *Bits = B.CreateLoad(Addr); + Value *Bits = B.CreateLoad(Int8Ty, Addr); Value *BitsAndBit = B.CreateAnd(Bits, Bit); auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, @@ -1495,8 +1495,10 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) { // The type of the function in the declaration is irrelevant because every // call site will cast it to the correct type. - auto *SingleImpl = M.getOrInsertFunction( - Res.SingleImplName, Type::getVoidTy(M.getContext())); + Constant *SingleImpl = + cast<Constant>(M.getOrInsertFunction(Res.SingleImplName, + Type::getVoidTy(M.getContext())) + .getCallee()); // This is the import phase so we should not be exporting anything. bool IsExported = false; @@ -1538,8 +1540,12 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { } if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) { - auto *JT = M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"), - Type::getVoidTy(M.getContext())); + // The type of the function is irrelevant, because it's bitcast at calls + // anyhow. + Constant *JT = cast<Constant>( + M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"), + Type::getVoidTy(M.getContext())) + .getCallee()); bool IsExported = false; applyICallBranchFunnel(SlotInfo, JT, IsExported); assert(!IsExported); @@ -1557,23 +1563,20 @@ void DevirtModule::removeRedundantTypeTests() { } bool DevirtModule::run() { + // If only some of the modules were split, we cannot correctly perform + // this transformation. 
We already checked for the presense of type tests + // with partially split modules during the thin link, and would have emitted + // an error if any were found, so here we can simply return. + if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) || + (ImportSummary && ImportSummary->partiallySplitLTOUnits())) + return false; + Function *TypeTestFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_test)); Function *TypeCheckedLoadFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); - // If only some of the modules were split, we cannot correctly handle - // code that contains type tests or type checked loads. - if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) || - (ImportSummary && ImportSummary->partiallySplitLTOUnits())) { - if ((TypeTestFunc && !TypeTestFunc->use_empty()) || - (TypeCheckedLoadFunc && !TypeCheckedLoadFunc->use_empty())) - report_fatal_error("inconsistent LTO Unit splitting with llvm.type.test " - "or llvm.type.checked.load"); - return false; - } - // Normally if there are no users of the devirtualization intrinsics in the // module, this pass has nothing to do. But if we are exporting, we also need // to handle any users that appear only in the function summaries. diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 6e196bfdbd25..ba15b023f2a3 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1,9 +1,8 @@ //===- InstCombineAddSub.cpp ------------------------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -823,6 +822,47 @@ static Value *checkForNegativeOperand(BinaryOperator &I, return nullptr; } +/// Wrapping flags may allow combining constants separated by an extend. +static Instruction *foldNoWrapAdd(BinaryOperator &Add, + InstCombiner::BuilderTy &Builder) { + Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1); + Type *Ty = Add.getType(); + Constant *Op1C; + if (!match(Op1, m_Constant(Op1C))) + return nullptr; + + // Try this match first because it results in an add in the narrow type. + // (zext (X +nuw C2)) + C1 --> zext (X + (C2 + trunc(C1))) + Value *X; + const APInt *C1, *C2; + if (match(Op1, m_APInt(C1)) && + match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2))))) && + C1->isNegative() && C1->sge(-C2->sext(C1->getBitWidth()))) { + Constant *NewC = + ConstantInt::get(X->getType(), *C2 + C1->trunc(C2->getBitWidth())); + return new ZExtInst(Builder.CreateNUWAdd(X, NewC), Ty); + } + + // More general combining of constants in the wide type. 
+ // (sext (X +nsw NarrowC)) + C --> (sext X) + (sext(NarrowC) + C) + Constant *NarrowC; + if (match(Op0, m_OneUse(m_SExt(m_NSWAdd(m_Value(X), m_Constant(NarrowC)))))) { + Constant *WideC = ConstantExpr::getSExt(NarrowC, Ty); + Constant *NewC = ConstantExpr::getAdd(WideC, Op1C); + Value *WideX = Builder.CreateSExt(X, Ty); + return BinaryOperator::CreateAdd(WideX, NewC); + } + // (zext (X +nuw NarrowC)) + C --> (zext X) + (zext(NarrowC) + C) + if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_Constant(NarrowC)))))) { + Constant *WideC = ConstantExpr::getZExt(NarrowC, Ty); + Constant *NewC = ConstantExpr::getAdd(WideC, Op1C); + Value *WideX = Builder.CreateZExt(X, Ty); + return BinaryOperator::CreateAdd(WideX, NewC); + } + + return nullptr; +} + Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1); Constant *Op1C; @@ -832,7 +872,14 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { if (Instruction *NV = foldBinOpIntoSelectOrPhi(Add)) return NV; - Value *X, *Y; + Value *X; + Constant *Op00C; + + // add (sub C1, X), C2 --> sub (add C1, C2), X + if (match(Op0, m_Sub(m_Constant(Op00C), m_Value(X)))) + return BinaryOperator::CreateSub(ConstantExpr::getAdd(Op00C, Op1C), X); + + Value *Y; // add (sub X, Y), -1 --> add (not Y), X if (match(Op0, m_OneUse(m_Sub(m_Value(X), m_Value(Y)))) && @@ -852,6 +899,11 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { if (!match(Op1, m_APInt(C))) return nullptr; + // (X | C2) + C --> (X | C2) ^ C2 iff (C2 == -C) + const APInt *C2; + if (match(Op0, m_Or(m_Value(), m_APInt(C2))) && *C2 == -*C) + return BinaryOperator::CreateXor(Op0, ConstantInt::get(Add.getType(), *C2)); + if (C->isSignMask()) { // If wrapping is not allowed, then the addition must set the sign bit: // X + (signmask) --> X | signmask @@ -866,19 +918,10 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { // Is this add the last step in a convoluted sext? 
// add(zext(xor i16 X, -32768), -32768) --> sext X Type *Ty = Add.getType(); - const APInt *C2; if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) && C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C) return CastInst::Create(Instruction::SExt, X, Ty); - // (add (zext (add nuw X, C2)), C) --> (zext (add nuw X, C2 + C)) - if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2))))) && - C->isNegative() && C->sge(-C2->sext(C->getBitWidth()))) { - Constant *NewC = - ConstantInt::get(X->getType(), *C2 + C->trunc(C2->getBitWidth())); - return new ZExtInst(Builder.CreateNUWAdd(X, NewC), Ty); - } - if (C->isOneValue() && Op0->hasOneUse()) { // add (sext i1 X), 1 --> zext (not X) // TODO: The smallest IR representation is (select X, 0, 1), and that would @@ -1032,6 +1075,28 @@ static Instruction *canonicalizeLowbitMask(BinaryOperator &I, return BinaryOperator::CreateNot(NotMask, I.getName()); } +static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) { + assert(I.getOpcode() == Instruction::Add && "Expecting add instruction"); + Type *Ty = I.getType(); + auto getUAddSat = [&]() { + return Intrinsic::getDeclaration(I.getModule(), Intrinsic::uadd_sat, Ty); + }; + + // add (umin X, ~Y), Y --> uaddsat X, Y + Value *X, *Y; + if (match(&I, m_c_Add(m_c_UMin(m_Value(X), m_Not(m_Value(Y))), + m_Deferred(Y)))) + return CallInst::Create(getUAddSat(), { X, Y }); + + // add (umin X, ~C), C --> uaddsat X, C + const APInt *C, *NotC; + if (match(&I, m_Add(m_UMin(m_Value(X), m_APInt(NotC)), m_APInt(C))) && + *C == ~*NotC) + return CallInst::Create(getUAddSat(), { X, ConstantInt::get(Ty, *C) }); + + return nullptr; +} + Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), @@ -1051,6 +1116,9 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Instruction *X = foldAddWithConstant(I)) return X; + if (Instruction *X = foldNoWrapAdd(I, Builder)) + return X; + // FIXME: This should be moved into the above helper function to allow these // transforms for general constant or constant splat vectors. Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); @@ -1119,6 +1187,12 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateSub(RHS, A); } + // Canonicalize sext to zext for better value tracking potential. + // add A, sext(B) --> sub A, zext(B) + if (match(&I, m_c_Add(m_Value(A), m_OneUse(m_SExt(m_Value(B))))) && + B->getType()->isIntOrIntVectorTy(1)) + return BinaryOperator::CreateSub(A, Builder.CreateZExt(B, Ty)); + // A + -B --> A - B if (match(RHS, m_Neg(m_Value(B)))) return BinaryOperator::CreateSub(LHS, B); @@ -1128,7 +1202,10 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // (A + 1) + ~B --> A - B // ~B + (A + 1) --> A - B - if (match(&I, m_c_BinOp(m_Add(m_Value(A), m_One()), m_Not(m_Value(B))))) + // (~B + A) + 1 --> A - B + // (A + ~B) + 1 --> A - B + if (match(&I, m_c_BinOp(m_Add(m_Value(A), m_One()), m_Not(m_Value(B)))) || + match(&I, m_BinOp(m_c_Add(m_Not(m_Value(B)), m_Value(A)), m_One()))) return BinaryOperator::CreateSub(A, B); // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) @@ -1225,6 +1302,9 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Instruction *V = canonicalizeLowbitMask(I, Builder)) return V; + if (Instruction *SatAdd = foldToUnsignedSaturatedAdd(I)) + return SatAdd; + return Changed ? 
&I : nullptr; } @@ -1500,6 +1580,12 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One())))) return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0); + // Y - ~X --> (X + 1) + Y + if (match(Op1, m_OneUse(m_Not(m_Value(X))))) { + return BinaryOperator::CreateAdd( + Builder.CreateAdd(Op0, ConstantInt::get(I.getType(), 1)), X); + } + if (Constant *C = dyn_cast<Constant>(Op0)) { bool IsNegate = match(C, m_ZeroInt()); Value *X; @@ -1532,8 +1618,13 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Instruction *R = foldOpIntoPhi(I, PN)) return R; - // C-(X+C2) --> (C-C2)-X Constant *C2; + + // C-(C2-X) --> X+(C-C2) + if (match(Op1, m_Sub(m_Constant(C2), m_Value(X)))) + return BinaryOperator::CreateAdd(X, ConstantExpr::getSub(C, C2)); + + // C-(X+C2) --> (C-C2)-X if (match(Op1, m_Add(m_Value(X), m_Constant(C2)))) return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); } @@ -1626,9 +1717,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Builder.CreateNot(Y, Y->getName() + ".not")); // 0 - (X sdiv C) -> (X sdiv -C) provided the negation doesn't overflow. - if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) && match(Op0, m_Zero()) && - C->isNotMinSignedValue() && !C->isOneValue()) - return BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(C)); + // TODO: This could be extended to match arbitrary vector constants. + const APInt *DivC; + if (match(Op0, m_Zero()) && match(Op1, m_SDiv(m_Value(X), m_APInt(DivC))) && + !DivC->isMinSignedValue() && *DivC != 1) { + Constant *NegDivC = ConstantInt::get(I.getType(), -(*DivC)); + Instruction *BO = BinaryOperator::CreateSDiv(X, NegDivC); + BO->setIsExact(cast<BinaryOperator>(Op1)->isExact()); + return BO; + } // 0 - (X << Y) -> (-X << Y) when X is freely negatable. if (match(Op1, m_Shl(m_Value(X), m_Value(Y))) && match(Op0, m_Zero())) @@ -1745,6 +1842,49 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return Changed ? &I : nullptr; } +/// This eliminates floating-point negation in either 'fneg(X)' or +/// 'fsub(-0.0, X)' form by combining into a constant operand. +static Instruction *foldFNegIntoConstant(Instruction &I) { + Value *X; + Constant *C; + + // Fold negation into constant operand. This is limited with one-use because + // fneg is assumed better for analysis and cheaper in codegen than fmul/fdiv. + // -(X * C) --> X * (-C) + // FIXME: It's arguable whether these should be m_OneUse or not. The current + // belief is that the FNeg allows for better reassociation opportunities. 
+ if (match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) + return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + // -(X / C) --> X / (-C) + if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) + return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + // -(C / X) --> (-C) / X + if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) + return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + + return nullptr; +} + +Instruction *InstCombiner::visitFNeg(UnaryOperator &I) { + Value *Op = I.getOperand(0); + + if (Value *V = SimplifyFNegInst(Op, I.getFastMathFlags(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + if (Instruction *X = foldFNegIntoConstant(I)) + return X; + + Value *X, *Y; + + // If we can ignore the sign of zeros: -(X - Y) --> (Y - X) + if (I.hasNoSignedZeros() && + match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) + return BinaryOperator::CreateFSubFMF(Y, X, &I); + + return nullptr; +} + Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -1760,21 +1900,12 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP())) return BinaryOperator::CreateFNegFMF(Op1, &I); + if (Instruction *X = foldFNegIntoConstant(I)) + return X; + Value *X, *Y; Constant *C; - // Fold negation into constant operand. This is limited with one-use because - // fneg is assumed better for analysis and cheaper in codegen than fmul/fdiv. - // -(X * C) --> X * (-C) - if (match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) - return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); - // -(X / C) --> X / (-C) - if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) - return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); - // -(C / X) --> (-C) / X - if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) - return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); - // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X) // Canonicalize to fadd to make analysis easier. // This can also help codegen because fadd is commutative. diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 404c2ad7e6e7..2b9859b602f4 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1,9 +1,8 @@ //===- InstCombineAndOrXor.cpp --------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -968,7 +967,7 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1, // Can it be decomposed into icmp eq (X & Mask), 0 ? if (llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1), Pred, X, UnsetBitsMask, - /*LookThruTrunc=*/false) && + /*LookThroughTrunc=*/false) && Pred == ICmpInst::ICMP_EQ) return true; // Is it icmp eq (X & Mask), 0 already? 
@@ -1022,6 +1021,36 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1, CxtI.getName() + ".simplified"); } +/// Reduce a pair of compares that check if a value has exactly 1 bit set. +static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd, + InstCombiner::BuilderTy &Builder) { + // Handle 'and' / 'or' commutation: make the equality check the first operand. + if (JoinedByAnd && Cmp1->getPredicate() == ICmpInst::ICMP_NE) + std::swap(Cmp0, Cmp1); + else if (!JoinedByAnd && Cmp1->getPredicate() == ICmpInst::ICMP_EQ) + std::swap(Cmp0, Cmp1); + + // (X != 0) && (ctpop(X) u< 2) --> ctpop(X) == 1 + CmpInst::Predicate Pred0, Pred1; + Value *X; + if (JoinedByAnd && match(Cmp0, m_ICmp(Pred0, m_Value(X), m_ZeroInt())) && + match(Cmp1, m_ICmp(Pred1, m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)), + m_SpecificInt(2))) && + Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_ULT) { + Value *CtPop = Cmp1->getOperand(0); + return Builder.CreateICmpEQ(CtPop, ConstantInt::get(CtPop->getType(), 1)); + } + // (X == 0) || (ctpop(X) u> 1) --> ctpop(X) != 1 + if (!JoinedByAnd && match(Cmp0, m_ICmp(Pred0, m_Value(X), m_ZeroInt())) && + match(Cmp1, m_ICmp(Pred1, m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)), + m_SpecificInt(1))) && + Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_UGT) { + Value *CtPop = Cmp1->getOperand(0); + return Builder.CreateICmpNE(CtPop, ConstantInt::get(CtPop->getType(), 1)); + } + return nullptr; +} + /// Fold (icmp)&(icmp) if possible. Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI) { @@ -1064,6 +1093,9 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldSignedTruncationCheck(LHS, RHS, CxtI, Builder)) return V; + if (Value *V = foldIsPowerOf2(LHS, RHS, true /* JoinedByAnd */, Builder)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); @@ -1259,6 +1291,52 @@ Value *InstCombiner::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) return nullptr; } +/// This a limited reassociation for a special case (see above) where we are +/// checking if two values are either both NAN (unordered) or not-NAN (ordered). +/// This could be handled more generally in '-reassociation', but it seems like +/// an unlikely pattern for a large number of logic ops and fcmps. +static Instruction *reassociateFCmps(BinaryOperator &BO, + InstCombiner::BuilderTy &Builder) { + Instruction::BinaryOps Opcode = BO.getOpcode(); + assert((Opcode == Instruction::And || Opcode == Instruction::Or) && + "Expecting and/or op for fcmp transform"); + + // There are 4 commuted variants of the pattern. Canonicalize operands of this + // logic op so an fcmp is operand 0 and a matching logic op is operand 1. + Value *Op0 = BO.getOperand(0), *Op1 = BO.getOperand(1), *X; + FCmpInst::Predicate Pred; + if (match(Op1, m_FCmp(Pred, m_Value(), m_AnyZeroFP()))) + std::swap(Op0, Op1); + + // Match inner binop and the predicate for combining 2 NAN checks into 1. + BinaryOperator *BO1; + FCmpInst::Predicate NanPred = Opcode == Instruction::And ? FCmpInst::FCMP_ORD + : FCmpInst::FCMP_UNO; + if (!match(Op0, m_FCmp(Pred, m_Value(X), m_AnyZeroFP())) || Pred != NanPred || + !match(Op1, m_BinOp(BO1)) || BO1->getOpcode() != Opcode) + return nullptr; + + // The inner logic op must have a matching fcmp operand. 
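The new foldIsPowerOf2 above merges a zero test with a ctpop range check. A standalone, exhaustive 16-bit check of the two equivalences (not part of this commit; C++20 std::popcount stands in for llvm.ctpop):

// Standalone sketch: verifies the equivalences behind foldIsPowerOf2:
//   (X != 0) && (ctpop(X) u< 2)  <=>  ctpop(X) == 1
//   (X == 0) || (ctpop(X) u> 1)  <=>  ctpop(X) != 1
#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t V = 0; V <= 0xFFFF; ++V) {
    uint16_t X = static_cast<uint16_t>(V);
    int Pop = std::popcount(X);
    assert(((X != 0) && (Pop < 2)) == (Pop == 1));
    assert(((X == 0) || (Pop > 1)) == (Pop != 1));
  }
  return 0;
}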
+ Value *BO10 = BO1->getOperand(0), *BO11 = BO1->getOperand(1), *Y; + if (!match(BO10, m_FCmp(Pred, m_Value(Y), m_AnyZeroFP())) || + Pred != NanPred || X->getType() != Y->getType()) + std::swap(BO10, BO11); + + if (!match(BO10, m_FCmp(Pred, m_Value(Y), m_AnyZeroFP())) || + Pred != NanPred || X->getType() != Y->getType()) + return nullptr; + + // and (fcmp ord X, 0), (and (fcmp ord Y, 0), Z) --> and (fcmp ord X, Y), Z + // or (fcmp uno X, 0), (or (fcmp uno Y, 0), Z) --> or (fcmp uno X, Y), Z + Value *NewFCmp = Builder.CreateFCmp(Pred, X, Y); + if (auto *NewFCmpInst = dyn_cast<FCmpInst>(NewFCmp)) { + // Intersect FMF from the 2 source fcmps. + NewFCmpInst->copyIRFlags(Op0); + NewFCmpInst->andIRFlags(BO10); + } + return BinaryOperator::Create(Opcode, NewFCmp, BO11); +} + /// Match De Morgan's Laws: /// (~A & ~B) == (~(A | B)) /// (~A | ~B) == (~(A & B)) @@ -1619,6 +1697,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth // of X and OP behaves well when given trunc(C1) and X. + // TODO: Do this for vectors by using m_APInt isntead of m_ConstantInt. switch (Op0I->getOpcode()) { default: break; @@ -1629,7 +1708,10 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { case Instruction::Sub: Value *X; ConstantInt *C1; - if (match(Op0I, m_c_BinOp(m_ZExt(m_Value(X)), m_ConstantInt(C1)))) { + // TODO: The one use restrictions could be relaxed a little if the AND + // is going to be removed. + if (match(Op0I, m_OneUse(m_c_BinOp(m_OneUse(m_ZExt(m_Value(X))), + m_ConstantInt(C1))))) { if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) { auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType()); Value *BinOp; @@ -1747,6 +1829,9 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *Res = foldLogicOfFCmps(LHS, RHS, true)) return replaceInstUsesWith(I, Res); + if (Instruction *FoldedFCmps = reassociateFCmps(I, Builder)) + return FoldedFCmps; + if (Instruction *CastedAnd = foldCastedBitwiseLogic(I)) return CastedAnd; @@ -1820,14 +1905,18 @@ static Instruction *matchRotate(Instruction &Or) { // First, find an or'd pair of opposite shifts with the same shifted operand: // or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1) - Value *Or0 = Or.getOperand(0), *Or1 = Or.getOperand(1); + BinaryOperator *Or0, *Or1; + if (!match(Or.getOperand(0), m_BinOp(Or0)) || + !match(Or.getOperand(1), m_BinOp(Or1))) + return nullptr; + Value *ShVal, *ShAmt0, *ShAmt1; if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) || !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1))))) return nullptr; - auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); - auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); + BinaryOperator::BinaryOps ShiftOpcode0 = Or0->getOpcode(); + BinaryOperator::BinaryOps ShiftOpcode1 = Or1->getOpcode(); if (ShiftOpcode0 == ShiftOpcode1) return nullptr; @@ -1842,6 +1931,13 @@ static Instruction *matchRotate(Instruction &Or) { match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))) return X; + // Similar to above, but the shift amount may be extended after masking, + // so return the extended value as the parameter for the intrinsic. 
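reassociateFCmps above is built on the fact that two "ordered with 0.0" (is-not-NaN) tests combine into one ordered comparison of the two values themselves. A standalone check of that core identity (not part of this commit; std::isunordered stands in for fcmp ord/uno):

// Standalone sketch: (X ord 0.0) && (Y ord 0.0)  <=>  (X ord Y),
// and the dual for unordered checks joined by 'or'.
#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const double NaN = std::numeric_limits<double>::quiet_NaN();
  const double Inf = std::numeric_limits<double>::infinity();
  const double Vals[] = {0.0, -1.5, 42.0, Inf, NaN};
  for (double X : Vals) {
    for (double Y : Vals) {
      bool OrdX = !std::isnan(X); // fcmp ord X, 0.0
      bool OrdY = !std::isnan(Y); // fcmp ord Y, 0.0
      assert((OrdX && OrdY) == !std::isunordered(X, Y));  // fcmp ord X, Y
      assert((!OrdX || !OrdY) == std::isunordered(X, Y)); // fcmp uno X, Y
    }
  }
  return 0;
}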
+ if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && + match(R, m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))), + m_SpecificInt(Mask)))) + return L; + return nullptr; }; @@ -2083,6 +2179,9 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, false, Builder)) return V; + if (Value *V = foldIsPowerOf2(LHS, RHS, false /* JoinedByAnd */, Builder)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). if (!LHSC || !RHSC) return nullptr; @@ -2412,6 +2511,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *Res = foldLogicOfFCmps(LHS, RHS, false)) return replaceInstUsesWith(I, Res); + if (Instruction *FoldedFCmps = reassociateFCmps(I, Builder)) + return FoldedFCmps; + if (Instruction *CastedOr = foldCastedBitwiseLogic(I)) return CastedOr; diff --git a/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp new file mode 100644 index 000000000000..5f37a00f56cf --- /dev/null +++ b/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp @@ -0,0 +1,159 @@ +//===- InstCombineAtomicRMW.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the visit functions for atomic rmw instructions. +// +//===----------------------------------------------------------------------===// +#include "InstCombineInternal.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +namespace { +/// Return true if and only if the given instruction does not modify the memory +/// location referenced. Note that an idemptent atomicrmw may still have +/// ordering effects on nearby instructions, or be volatile. +/// TODO: Common w/ the version in AtomicExpandPass, and change the term used. +/// Idemptotent is confusing in this context. +bool isIdempotentRMW(AtomicRMWInst& RMWI) { + if (auto CF = dyn_cast<ConstantFP>(RMWI.getValOperand())) + switch(RMWI.getOperation()) { + case AtomicRMWInst::FAdd: // -0.0 + return CF->isZero() && CF->isNegative(); + case AtomicRMWInst::FSub: // +0.0 + return CF->isZero() && !CF->isNegative(); + default: + return false; + }; + + auto C = dyn_cast<ConstantInt>(RMWI.getValOperand()); + if(!C) + return false; + + switch(RMWI.getOperation()) { + case AtomicRMWInst::Add: + case AtomicRMWInst::Sub: + case AtomicRMWInst::Or: + case AtomicRMWInst::Xor: + return C->isZero(); + case AtomicRMWInst::And: + return C->isMinusOne(); + case AtomicRMWInst::Min: + return C->isMaxValue(true); + case AtomicRMWInst::Max: + return C->isMinValue(true); + case AtomicRMWInst::UMin: + return C->isMaxValue(false); + case AtomicRMWInst::UMax: + return C->isMinValue(false); + default: + return false; + } +} + +/// Return true if the given instruction always produces a value in memory +/// equivalent to its value operand. 
+bool isSaturating(AtomicRMWInst& RMWI) { + if (auto CF = dyn_cast<ConstantFP>(RMWI.getValOperand())) + switch(RMWI.getOperation()) { + case AtomicRMWInst::FAdd: + case AtomicRMWInst::FSub: + return CF->isNaN(); + default: + return false; + }; + + auto C = dyn_cast<ConstantInt>(RMWI.getValOperand()); + if(!C) + return false; + + switch(RMWI.getOperation()) { + default: + return false; + case AtomicRMWInst::Xchg: + return true; + case AtomicRMWInst::Or: + return C->isAllOnesValue(); + case AtomicRMWInst::And: + return C->isZero(); + case AtomicRMWInst::Min: + return C->isMinValue(true); + case AtomicRMWInst::Max: + return C->isMaxValue(true); + case AtomicRMWInst::UMin: + return C->isMinValue(false); + case AtomicRMWInst::UMax: + return C->isMaxValue(false); + }; +} +} + +Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) { + + // Volatile RMWs perform a load and a store, we cannot replace this by just a + // load or just a store. We chose not to canonicalize out of general paranoia + // about user expectations around volatile. + if (RMWI.isVolatile()) + return nullptr; + + // Any atomicrmw op which produces a known result in memory can be + // replaced w/an atomicrmw xchg. + if (isSaturating(RMWI) && + RMWI.getOperation() != AtomicRMWInst::Xchg) { + RMWI.setOperation(AtomicRMWInst::Xchg); + return &RMWI; + } + + AtomicOrdering Ordering = RMWI.getOrdering(); + assert(Ordering != AtomicOrdering::NotAtomic && + Ordering != AtomicOrdering::Unordered && + "AtomicRMWs don't make sense with Unordered or NotAtomic"); + + // Any atomicrmw xchg with no uses can be converted to a atomic store if the + // ordering is compatible. + if (RMWI.getOperation() == AtomicRMWInst::Xchg && + RMWI.use_empty()) { + if (Ordering != AtomicOrdering::Release && + Ordering != AtomicOrdering::Monotonic) + return nullptr; + auto *SI = new StoreInst(RMWI.getValOperand(), + RMWI.getPointerOperand(), &RMWI); + SI->setAtomic(Ordering, RMWI.getSyncScopeID()); + SI->setAlignment(DL.getABITypeAlignment(RMWI.getType())); + return eraseInstFromFunction(RMWI); + } + + if (!isIdempotentRMW(RMWI)) + return nullptr; + + // We chose to canonicalize all idempotent operations to an single + // operation code and constant. This makes it easier for the rest of the + // optimizer to match easily. The choices of or w/0 and fadd w/-0.0 are + // arbitrary. + if (RMWI.getType()->isIntegerTy() && + RMWI.getOperation() != AtomicRMWInst::Or) { + RMWI.setOperation(AtomicRMWInst::Or); + RMWI.setOperand(1, ConstantInt::get(RMWI.getType(), 0)); + return &RMWI; + } else if (RMWI.getType()->isFloatingPointTy() && + RMWI.getOperation() != AtomicRMWInst::FAdd) { + RMWI.setOperation(AtomicRMWInst::FAdd); + RMWI.setOperand(1, ConstantFP::getNegativeZero(RMWI.getType())); + return &RMWI; + } + + // Check if the required ordering is compatible with an atomic load. 
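The idempotent/saturating classification in the new file above boils down to plain integer identities. A standalone scalar check (not part of this commit); it deliberately ignores the atomicity, ordering, and volatility constraints the pass itself has to respect:

// Standalone sketch of the classification rules:
//  - "idempotent" ops leave memory unchanged (result == old value);
//  - "saturating" ops force memory to the value operand (behave like xchg).
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (int64_t V = -3; V <= 3; ++V) {
    int32_t X = static_cast<int32_t>(V * 1000003);
    uint32_t U = static_cast<uint32_t>(X);
    // Idempotent: add/sub/or/xor 0, and -1, min/max against the far extreme.
    assert(X + 0 == X && (X | 0) == X && (X ^ 0) == X && (X & -1) == X);
    assert(std::min<int32_t>(X, INT32_MAX) == X && std::max<int32_t>(X, INT32_MIN) == X);
    assert(std::min<uint32_t>(U, UINT32_MAX) == U && std::max<uint32_t>(U, 0u) == U);
    // Saturating: or -1, and 0, min/max against the near extreme.
    assert((X | -1) == -1 && (X & 0) == 0);
    assert(std::min<int32_t>(X, INT32_MIN) == INT32_MIN &&
           std::max<int32_t>(X, INT32_MAX) == INT32_MAX);
    assert(std::min<uint32_t>(U, 0u) == 0u &&
           std::max<uint32_t>(U, UINT32_MAX) == UINT32_MAX);
  }
  return 0;
}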
+ if (Ordering != AtomicOrdering::Acquire && + Ordering != AtomicOrdering::Monotonic) + return nullptr; + + LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand()); + Load->setAtomic(Ordering, RMWI.getSyncScopeID()); + Load->setAlignment(DL.getABITypeAlignment(RMWI.getType())); + return Load; +} diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index aeb25d530d71..4b3333affa72 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1,19 +1,19 @@ //===- InstCombineCalls.cpp -----------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // -// This file implements the visitCall and visitInvoke functions. +// This file implements the visitCall, visitInvoke, and visitCallBr functions. // //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/APSInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" @@ -23,12 +23,12 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -58,6 +58,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" #include <algorithm> #include <cassert> @@ -121,6 +122,15 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { return MI; } + // If we have a store to a location which is known constant, we can conclude + // that the store must be storing the constant value (else the memory + // wouldn't be constant), and this must be a noop. + if (AA->pointsToConstantMemory(MI->getDest())) { + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(MI->getLength()->getType())); + return MI; + } + // If MemCpyInst length is 1/2/4/8 bytes then replace memcpy with // load/store. ConstantInt *MemOpLength = dyn_cast<ConstantInt>(MI->getLength()); @@ -173,7 +183,7 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { Value *Src = Builder.CreateBitCast(MI->getArgOperand(1), NewSrcPtrTy); Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy); - LoadInst *L = Builder.CreateLoad(Src); + LoadInst *L = Builder.CreateLoad(IntType, Src); // Alignment from the mem intrinsic will be better, so use it. 
L->setAlignment(CopySrcAlign); if (CopyMD) @@ -219,6 +229,15 @@ Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) { return MI; } + // If we have a store to a location which is known constant, we can conclude + // that the store must be storing the constant value (else the memory + // wouldn't be constant), and this must be a noop. + if (AA->pointsToConstantMemory(MI->getDest())) { + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(MI->getLength()->getType())); + return MI; + } + // Extract the length and alignment and fill if they are constant. ConstantInt *LenC = dyn_cast<ConstantInt>(MI->getLength()); ConstantInt *FillC = dyn_cast<ConstantInt>(MI->getValue()); @@ -523,7 +542,8 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, return Builder.CreateAShr(Vec, ShiftVec); } -static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) { +static Value *simplifyX86pack(IntrinsicInst &II, + InstCombiner::BuilderTy &Builder, bool IsSigned) { Value *Arg0 = II.getArgOperand(0); Value *Arg1 = II.getArgOperand(1); Type *ResTy = II.getType(); @@ -534,167 +554,58 @@ static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) { Type *ArgTy = Arg0->getType(); unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128; - unsigned NumDstElts = ResTy->getVectorNumElements(); unsigned NumSrcElts = ArgTy->getVectorNumElements(); - assert(NumDstElts == (2 * NumSrcElts) && "Unexpected packing types"); + assert(ResTy->getVectorNumElements() == (2 * NumSrcElts) && + "Unexpected packing types"); - unsigned NumDstEltsPerLane = NumDstElts / NumLanes; unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes; unsigned DstScalarSizeInBits = ResTy->getScalarSizeInBits(); - assert(ArgTy->getScalarSizeInBits() == (2 * DstScalarSizeInBits) && + unsigned SrcScalarSizeInBits = ArgTy->getScalarSizeInBits(); + assert(SrcScalarSizeInBits == (2 * DstScalarSizeInBits) && "Unexpected packing types"); // Constant folding. - auto *Cst0 = dyn_cast<Constant>(Arg0); - auto *Cst1 = dyn_cast<Constant>(Arg1); - if (!Cst0 || !Cst1) + if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1)) return nullptr; - SmallVector<Constant *, 32> Vals; - for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { - for (unsigned Elt = 0; Elt != NumDstEltsPerLane; ++Elt) { - unsigned SrcIdx = Lane * NumSrcEltsPerLane + Elt % NumSrcEltsPerLane; - auto *Cst = (Elt >= NumSrcEltsPerLane) ? Cst1 : Cst0; - auto *COp = Cst->getAggregateElement(SrcIdx); - if (COp && isa<UndefValue>(COp)) { - Vals.push_back(UndefValue::get(ResTy->getScalarType())); - continue; - } - - auto *CInt = dyn_cast_or_null<ConstantInt>(COp); - if (!CInt) - return nullptr; - - APInt Val = CInt->getValue(); - assert(Val.getBitWidth() == ArgTy->getScalarSizeInBits() && - "Unexpected constant bitwidth"); - - if (IsSigned) { - // PACKSS: Truncate signed value with signed saturation. - // Source values less than dst minint are saturated to minint. - // Source values greater than dst maxint are saturated to maxint. - if (Val.isSignedIntN(DstScalarSizeInBits)) - Val = Val.trunc(DstScalarSizeInBits); - else if (Val.isNegative()) - Val = APInt::getSignedMinValue(DstScalarSizeInBits); - else - Val = APInt::getSignedMaxValue(DstScalarSizeInBits); - } else { - // PACKUS: Truncate signed value with unsigned saturation. - // Source values less than zero are saturated to zero. - // Source values greater than dst maxuint are saturated to maxuint. 
- if (Val.isIntN(DstScalarSizeInBits)) - Val = Val.trunc(DstScalarSizeInBits); - else if (Val.isNegative()) - Val = APInt::getNullValue(DstScalarSizeInBits); - else - Val = APInt::getAllOnesValue(DstScalarSizeInBits); - } - - Vals.push_back(ConstantInt::get(ResTy->getScalarType(), Val)); - } - } - - return ConstantVector::get(Vals); -} - -// Replace X86-specific intrinsics with generic floor-ceil where applicable. -static Value *simplifyX86round(IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - ConstantInt *Arg = nullptr; - Intrinsic::ID IntrinsicID = II.getIntrinsicID(); - - if (IntrinsicID == Intrinsic::x86_sse41_round_ss || - IntrinsicID == Intrinsic::x86_sse41_round_sd) - Arg = dyn_cast<ConstantInt>(II.getArgOperand(2)); - else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) - Arg = dyn_cast<ConstantInt>(II.getArgOperand(4)); - else - Arg = dyn_cast<ConstantInt>(II.getArgOperand(1)); - if (!Arg) - return nullptr; - unsigned RoundControl = Arg->getZExtValue(); - - Arg = nullptr; - unsigned SAE = 0; - if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512) - Arg = dyn_cast<ConstantInt>(II.getArgOperand(4)); - else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) - Arg = dyn_cast<ConstantInt>(II.getArgOperand(5)); - else - SAE = 4; - if (!SAE) { - if (!Arg) - return nullptr; - SAE = Arg->getZExtValue(); + // Clamp Values - signed/unsigned both use signed clamp values, but they + // differ on the min/max values. + APInt MinValue, MaxValue; + if (IsSigned) { + // PACKSS: Truncate signed value with signed saturation. + // Source values less than dst minint are saturated to minint. + // Source values greater than dst maxint are saturated to maxint. + MinValue = + APInt::getSignedMinValue(DstScalarSizeInBits).sext(SrcScalarSizeInBits); + MaxValue = + APInt::getSignedMaxValue(DstScalarSizeInBits).sext(SrcScalarSizeInBits); + } else { + // PACKUS: Truncate signed value with unsigned saturation. + // Source values less than zero are saturated to zero. + // Source values greater than dst maxuint are saturated to maxuint. 
+ MinValue = APInt::getNullValue(SrcScalarSizeInBits); + MaxValue = APInt::getLowBitsSet(SrcScalarSizeInBits, DstScalarSizeInBits); } - if (SAE != 4 || (RoundControl != 2 /*ceil*/ && RoundControl != 1 /*floor*/)) - return nullptr; + auto *MinC = Constant::getIntegerValue(ArgTy, MinValue); + auto *MaxC = Constant::getIntegerValue(ArgTy, MaxValue); + Arg0 = Builder.CreateSelect(Builder.CreateICmpSLT(Arg0, MinC), MinC, Arg0); + Arg1 = Builder.CreateSelect(Builder.CreateICmpSLT(Arg1, MinC), MinC, Arg1); + Arg0 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg0, MaxC), MaxC, Arg0); + Arg1 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg1, MaxC), MaxC, Arg1); - Value *Src, *Dst, *Mask; - bool IsScalar = false; - if (IntrinsicID == Intrinsic::x86_sse41_round_ss || - IntrinsicID == Intrinsic::x86_sse41_round_sd || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { - IsScalar = true; - if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { - Mask = II.getArgOperand(3); - Value *Zero = Constant::getNullValue(Mask->getType()); - Mask = Builder.CreateAnd(Mask, 1); - Mask = Builder.CreateICmp(ICmpInst::ICMP_NE, Mask, Zero); - Dst = II.getArgOperand(2); - } else - Dst = II.getArgOperand(0); - Src = Builder.CreateExtractElement(II.getArgOperand(1), (uint64_t)0); - } else { - Src = II.getArgOperand(0); - if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_128 || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_256 || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_128 || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_256 || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512) { - Dst = II.getArgOperand(2); - Mask = II.getArgOperand(3); - } else { - Dst = Src; - Mask = ConstantInt::getAllOnesValue( - Builder.getIntNTy(Src->getType()->getVectorNumElements())); - } + // Shuffle clamped args together at the lane level. + SmallVector<unsigned, 32> PackMask; + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt) + PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane)); + for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt) + PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane) + NumSrcElts); } + auto *Shuffle = Builder.CreateShuffleVector(Arg0, Arg1, PackMask); - Intrinsic::ID ID = (RoundControl == 2) ? Intrinsic::ceil : Intrinsic::floor; - Value *Res = Builder.CreateUnaryIntrinsic(ID, Src, &II); - if (!IsScalar) { - if (auto *C = dyn_cast<Constant>(Mask)) - if (C->isAllOnesValue()) - return Res; - auto *MaskTy = VectorType::get( - Builder.getInt1Ty(), cast<IntegerType>(Mask->getType())->getBitWidth()); - Mask = Builder.CreateBitCast(Mask, MaskTy); - unsigned Width = Src->getType()->getVectorNumElements(); - if (MaskTy->getVectorNumElements() > Width) { - uint32_t Indices[4]; - for (unsigned i = 0; i != Width; ++i) - Indices[i] = i; - Mask = Builder.CreateShuffleVector(Mask, Mask, - makeArrayRef(Indices, Width)); - } - return Builder.CreateSelect(Mask, Res, Dst); - } - if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || - IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { - Dst = Builder.CreateExtractElement(Dst, (uint64_t)0); - Res = Builder.CreateSelect(Mask, Res, Dst); - Dst = II.getArgOperand(0); - } - return Builder.CreateInsertElement(Dst, Res, (uint64_t)0); + // Truncate to dst size. 
+ return Builder.CreateTrunc(Shuffle, ResTy); } static Value *simplifyX86movmsk(const IntrinsicInst &II, @@ -711,43 +622,44 @@ static Value *simplifyX86movmsk(const IntrinsicInst &II, if (!ArgTy->isVectorTy()) return nullptr; - if (auto *C = dyn_cast<Constant>(Arg)) { - // Extract signbits of the vector input and pack into integer result. - APInt Result(ResTy->getPrimitiveSizeInBits(), 0); - for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) { - auto *COp = C->getAggregateElement(I); - if (!COp) - return nullptr; - if (isa<UndefValue>(COp)) - continue; - - auto *CInt = dyn_cast<ConstantInt>(COp); - auto *CFp = dyn_cast<ConstantFP>(COp); - if (!CInt && !CFp) - return nullptr; - - if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative())) - Result.setBit(I); - } - return Constant::getIntegerValue(ResTy, Result); - } + // Expand MOVMSK to compare/bitcast/zext: + // e.g. PMOVMSKB(v16i8 x): + // %cmp = icmp slt <16 x i8> %x, zeroinitializer + // %int = bitcast <16 x i1> %cmp to i16 + // %res = zext i16 %int to i32 + unsigned NumElts = ArgTy->getVectorNumElements(); + Type *IntegerVecTy = VectorType::getInteger(cast<VectorType>(ArgTy)); + Type *IntegerTy = Builder.getIntNTy(NumElts); + + Value *Res = Builder.CreateBitCast(Arg, IntegerVecTy); + Res = Builder.CreateICmpSLT(Res, Constant::getNullValue(IntegerVecTy)); + Res = Builder.CreateBitCast(Res, IntegerTy); + Res = Builder.CreateZExtOrTrunc(Res, ResTy); + return Res; +} - // Look for a sign-extended boolean source vector as the argument to this - // movmsk. If the argument is bitcast, look through that, but make sure the - // source of that bitcast is still a vector with the same number of elements. - // TODO: We can also convert a bitcast with wider elements, but that requires - // duplicating the bool source sign bits to match the number of elements - // expected by the movmsk call. - Arg = peekThroughBitcast(Arg); - Value *X; - if (Arg->getType()->isVectorTy() && - Arg->getType()->getVectorNumElements() == ArgTy->getVectorNumElements() && - match(Arg, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - // call iM movmsk(sext <N x i1> X) --> zext (bitcast <N x i1> X to iN) to iM - unsigned NumElts = X->getType()->getVectorNumElements(); - Type *ScalarTy = Type::getIntNTy(Arg->getContext(), NumElts); - Value *BC = Builder.CreateBitCast(X, ScalarTy); - return Builder.CreateZExtOrTrunc(BC, ResTy); +static Value *simplifyX86addcarry(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + Value *CarryIn = II.getArgOperand(0); + Value *Op1 = II.getArgOperand(1); + Value *Op2 = II.getArgOperand(2); + Type *RetTy = II.getType(); + Type *OpTy = Op1->getType(); + assert(RetTy->getStructElementType(0)->isIntegerTy(8) && + RetTy->getStructElementType(1) == OpTy && OpTy == Op2->getType() && + "Unexpected types for x86 addcarry"); + + // If carry-in is zero, this is just an unsigned add with overflow. + if (match(CarryIn, m_ZeroInt())) { + Value *UAdd = Builder.CreateIntrinsic(Intrinsic::uadd_with_overflow, OpTy, + { Op1, Op2 }); + // The types have to be adjusted to match the x86 call types. 
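The rewritten simplifyX86pack above clamps each wide source element into the destination range with icmp+select pairs and then truncates. A standalone, exhaustive check (not part of this commit) that this clamp-then-truncate matches the saturation rule the removed per-element constant folding implemented; PackSSRef/PackUSRef are local reference helpers:

// Standalone sketch: 16-bit -> 8-bit PACKSS/PACKUS semantics.
#include <algorithm>
#include <cassert>
#include <cstdint>

static int8_t PackSSRef(int16_t V) {
  // Mirrors the removed constant-folding: truncate if it fits, else saturate
  // toward the sign (signed saturation).
  if (V >= INT8_MIN && V <= INT8_MAX) return static_cast<int8_t>(V);
  return V < 0 ? INT8_MIN : INT8_MAX;
}
static uint8_t PackUSRef(int16_t V) {
  // Unsigned saturation: negatives go to 0, large values to 255.
  if (V >= 0 && V <= UINT8_MAX) return static_cast<uint8_t>(V);
  return V < 0 ? 0 : UINT8_MAX;
}

int main() {
  for (int V = INT16_MIN; V <= INT16_MAX; ++V) {
    int16_t S = static_cast<int16_t>(V);
    // New style: clamp in the wide type, then truncate.
    int8_t SS = static_cast<int8_t>(std::clamp<int>(S, INT8_MIN, INT8_MAX));
    uint8_t US = static_cast<uint8_t>(std::clamp<int>(S, 0, UINT8_MAX));
    assert(SS == PackSSRef(S) && US == PackUSRef(S));
  }
  return 0;
}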
+ Value *UAddResult = Builder.CreateExtractValue(UAdd, 0); + Value *UAddOV = Builder.CreateZExt(Builder.CreateExtractValue(UAdd, 1), + Builder.getInt8Ty()); + Value *Res = UndefValue::get(RetTy); + Res = Builder.CreateInsertValue(Res, UAddOV, 0); + return Builder.CreateInsertValue(Res, UAddResult, 1); } return nullptr; @@ -892,7 +804,7 @@ static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, if (II.getIntrinsicID() == Intrinsic::x86_sse4a_extrq) { Value *Args[] = {Op0, CILength, CIIndex}; Module *M = II.getModule(); - Value *F = Intrinsic::getDeclaration(M, Intrinsic::x86_sse4a_extrqi); + Function *F = Intrinsic::getDeclaration(M, Intrinsic::x86_sse4a_extrqi); return Builder.CreateCall(F, Args); } } @@ -993,7 +905,7 @@ static Value *simplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, Value *Args[] = {Op0, Op1, CILength, CIIndex}; Module *M = II.getModule(); - Value *F = Intrinsic::getDeclaration(M, Intrinsic::x86_sse4a_insertqi); + Function *F = Intrinsic::getDeclaration(M, Intrinsic::x86_sse4a_insertqi); return Builder.CreateCall(F, Args); } @@ -1134,82 +1046,42 @@ static Value *simplifyX86vpermv(const IntrinsicInst &II, return Builder.CreateShuffleVector(V1, V2, ShuffleMask); } -/// Decode XOP integer vector comparison intrinsics. -static Value *simplifyX86vpcom(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder, - bool IsSigned) { - if (auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2))) { - uint64_t Imm = CInt->getZExtValue() & 0x7; - VectorType *VecTy = cast<VectorType>(II.getType()); - CmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; - - switch (Imm) { - case 0x0: - Pred = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; - break; - case 0x1: - Pred = IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; - break; - case 0x2: - Pred = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; - break; - case 0x3: - Pred = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; - break; - case 0x4: - Pred = ICmpInst::ICMP_EQ; break; - case 0x5: - Pred = ICmpInst::ICMP_NE; break; - case 0x6: - return ConstantInt::getSigned(VecTy, 0); // FALSE - case 0x7: - return ConstantInt::getSigned(VecTy, -1); // TRUE - } - - if (Value *Cmp = Builder.CreateICmp(Pred, II.getArgOperand(0), - II.getArgOperand(1))) - return Builder.CreateSExtOrTrunc(Cmp, VecTy); - } - return nullptr; -} +// TODO, Obvious Missing Transforms: +// * Narrow width by halfs excluding zero/undef lanes +Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) { + Value *LoadPtr = II.getArgOperand(0); + unsigned Alignment = cast<ConstantInt>(II.getArgOperand(1))->getZExtValue(); -static bool maskIsAllOneOrUndef(Value *Mask) { - auto *ConstMask = dyn_cast<Constant>(Mask); - if (!ConstMask) - return false; - if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask)) - return true; - for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; - ++I) { - if (auto *MaskElt = ConstMask->getAggregateElement(I)) - if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt)) - continue; - return false; - } - return true; -} - -static Value *simplifyMaskedLoad(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { // If the mask is all ones or undefs, this is a plain vector load of the 1st // argument. 
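simplifyX86addcarry above folds an addcarry whose carry-in is known zero into llvm.uadd.with.overflow. A standalone model of that equivalence (not part of this commit), using the GCC/Clang __builtin_add_overflow as a stand-in for the intrinsic:

// Standalone sketch: with carry-in 0, addcarry's (carry-out, sum) pair is the
// same as an unsigned add-with-overflow.
#include <cassert>
#include <cstdint>

int main() {
  const uint64_t Samples[] = {0, 1, 42, 0x8000000000000000ULL,
                              UINT64_MAX - 1, UINT64_MAX};
  for (uint64_t A : Samples) {
    for (uint64_t B : Samples) {
      uint64_t Sum = A + B;
      uint8_t CarryOut = Sum < A; // wrapped iff the carry-out bit is set
      uint64_t OvSum;
      bool Overflow = __builtin_add_overflow(A, B, &OvSum);
      assert(Sum == OvSum && CarryOut == static_cast<uint8_t>(Overflow));
    }
  }
  return 0;
}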
- if (maskIsAllOneOrUndef(II.getArgOperand(2))) { - Value *LoadPtr = II.getArgOperand(0); - unsigned Alignment = cast<ConstantInt>(II.getArgOperand(1))->getZExtValue(); - return Builder.CreateAlignedLoad(LoadPtr, Alignment, "unmaskedload"); + if (maskIsAllOneOrUndef(II.getArgOperand(2))) + return Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, + "unmaskedload"); + + // If we can unconditionally load from this address, replace with a + // load/select idiom. TODO: use DT for context sensitive query + if (isDereferenceableAndAlignedPointer(LoadPtr, II.getType(), Alignment, + II.getModule()->getDataLayout(), + &II, nullptr)) { + Value *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, + "unmaskedload"); + return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3)); } return nullptr; } -static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) { +// TODO, Obvious Missing Transforms: +// * Single constant active lane -> store +// * Narrow width by halfs excluding zero/undef lanes +Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) { auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); if (!ConstMask) return nullptr; // If the mask is all zeros, this instruction does nothing. if (ConstMask->isNullValue()) - return IC.eraseInstFromFunction(II); + return eraseInstFromFunction(II); // If the mask is all ones, this is a plain vector store of the 1st argument. if (ConstMask->isAllOnesValue()) { @@ -1218,14 +1090,57 @@ static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) { return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment); } + // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts + APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask); + APInt UndefElts(DemandedElts.getBitWidth(), 0); + if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), + DemandedElts, UndefElts)) { + II.setOperand(0, V); + return &II; + } + return nullptr; } -static Instruction *simplifyMaskedGather(IntrinsicInst &II, InstCombiner &IC) { - // If the mask is all zeros, return the "passthru" argument of the gather. - auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2)); - if (ConstMask && ConstMask->isNullValue()) - return IC.replaceInstUsesWith(II, II.getArgOperand(3)); +// TODO, Obvious Missing Transforms: +// * Single constant active lane load -> load +// * Dereferenceable address & few lanes -> scalarize speculative load/selects +// * Adjacent vector addresses -> masked.load +// * Narrow width by halfs excluding zero/undef lanes +// * Vector splat address w/known mask -> scalar load +// * Vector incrementing address -> vector masked load +Instruction *InstCombiner::simplifyMaskedGather(IntrinsicInst &II) { + return nullptr; +} + +// TODO, Obvious Missing Transforms: +// * Single constant active lane -> store +// * Adjacent vector addresses -> masked.store +// * Narrow store width by halfs excluding zero/undef lanes +// * Vector splat address w/known mask -> scalar store +// * Vector incrementing address -> vector masked store +Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) { + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + if (!ConstMask) + return nullptr; + + // If the mask is all zeros, a scatter does nothing. 
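The new masked-load rewrite above is only legal because every lane of the unconditional load is known dereferenceable; given that, load-then-select matches the per-lane masked semantics. A minimal scalar model (not part of this commit):

// Standalone sketch: per lane, masked.load(p, mask, passthru) equals
// select(mask, load(p), passthru) when loading p is always safe.
#include <array>
#include <cassert>
#include <cstddef>

int main() {
  std::array<int, 4> Mem = {10, 20, 30, 40};       // fully dereferenceable
  std::array<bool, 4> Mask = {true, false, true, false};
  std::array<int, 4> Passthru = {-1, -2, -3, -4};
  for (std::size_t I = 0; I < Mem.size(); ++I) {
    int Masked = Mask[I] ? Mem[I] : Passthru[I];   // masked-load lane
    int Loaded = Mem[I];                           // unconditional load lane
    int Selected = Mask[I] ? Loaded : Passthru[I]; // select(mask, load, passthru)
    assert(Masked == Selected);
  }
  return 0;
}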
+ if (ConstMask->isNullValue()) + return eraseInstFromFunction(II); + + // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts + APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask); + APInt UndefElts(DemandedElts.getBitWidth(), 0); + if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), + DemandedElts, UndefElts)) { + II.setOperand(0, V); + return &II; + } + if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1), + DemandedElts, UndefElts)) { + II.setOperand(1, V); + return &II; + } return nullptr; } @@ -1264,25 +1179,41 @@ static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II, return cast<Instruction>(Result); } -static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) { - // If the mask is all zeros, a scatter does nothing. - auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); - if (ConstMask && ConstMask->isNullValue()) - return IC.eraseInstFromFunction(II); - - return nullptr; -} - static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) { assert((II.getIntrinsicID() == Intrinsic::cttz || II.getIntrinsicID() == Intrinsic::ctlz) && "Expected cttz or ctlz intrinsic"); + bool IsTZ = II.getIntrinsicID() == Intrinsic::cttz; Value *Op0 = II.getArgOperand(0); + Value *X; + // ctlz(bitreverse(x)) -> cttz(x) + // cttz(bitreverse(x)) -> ctlz(x) + if (match(Op0, m_BitReverse(m_Value(X)))) { + Intrinsic::ID ID = IsTZ ? Intrinsic::ctlz : Intrinsic::cttz; + Function *F = Intrinsic::getDeclaration(II.getModule(), ID, II.getType()); + return CallInst::Create(F, {X, II.getArgOperand(1)}); + } + + if (IsTZ) { + // cttz(-x) -> cttz(x) + if (match(Op0, m_Neg(m_Value(X)))) { + II.setOperand(0, X); + return &II; + } + + // cttz(abs(x)) -> cttz(x) + // cttz(nabs(x)) -> cttz(x) + Value *Y; + SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor; + if (SPF == SPF_ABS || SPF == SPF_NABS) { + II.setOperand(0, X); + return &II; + } + } KnownBits Known = IC.computeKnownBits(Op0, 0, &II); // Create a mask for bits above (ctlz) or below (cttz) the first known one. - bool IsTZ = II.getIntrinsicID() == Intrinsic::cttz; unsigned PossibleZeros = IsTZ ? Known.countMaxTrailingZeros() : Known.countMaxLeadingZeros(); unsigned DefiniteZeros = IsTZ ? Known.countMinTrailingZeros() @@ -1328,6 +1259,14 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombiner &IC) { assert(II.getIntrinsicID() == Intrinsic::ctpop && "Expected ctpop intrinsic"); Value *Op0 = II.getArgOperand(0); + Value *X; + // ctpop(bitreverse(x)) -> ctpop(x) + // ctpop(bswap(x)) -> ctpop(x) + if (match(Op0, m_BitReverse(m_Value(X))) || match(Op0, m_BSwap(m_Value(X)))) { + II.setOperand(0, X); + return &II; + } + // FIXME: Try to simplify vectors of integers. 
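The cttz/ctlz/ctpop folds above are pure bit identities. A standalone check (not part of this commit; C++20 <bit>, with local BitReverse/ByteSwap helpers standing in for llvm.bitreverse and llvm.bswap):

// Standalone sketch:
//   ctlz(bitreverse(x)) == cttz(x)     cttz(-x) == cttz(x)
//   ctpop(bitreverse(x)) == ctpop(x)   ctpop(bswap(x)) == ctpop(x)
#include <bit>
#include <cassert>
#include <cstdint>

static uint32_t BitReverse(uint32_t X) {
  uint32_t R = 0;
  for (int I = 0; I < 32; ++I)
    R |= ((X >> I) & 1u) << (31 - I);
  return R;
}
static uint32_t ByteSwap(uint32_t X) {
  return (X << 24) | ((X & 0xFF00u) << 8) | ((X >> 8) & 0xFF00u) | (X >> 24);
}

int main() {
  const uint32_t Samples[] = {0u, 1u, 2u, 0x80000000u, 0x00F0F000u,
                              0xDEADBEEFu, UINT32_MAX};
  for (uint32_t X : Samples) {
    assert(std::countl_zero(BitReverse(X)) == std::countr_zero(X));
    assert(std::countr_zero(0u - X) == std::countr_zero(X)); // cttz(-x) == cttz(x)
    assert(std::popcount(BitReverse(X)) == std::popcount(X));
    assert(std::popcount(ByteSwap(X)) == std::popcount(X));
  }
  return 0;
}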
auto *IT = dyn_cast<IntegerType>(Op0->getType()); if (!IT) @@ -1513,7 +1452,7 @@ static Value *simplifyNeonVld1(const IntrinsicInst &II, auto *BCastInst = Builder.CreateBitCast(II.getArgOperand(0), PointerType::get(II.getType(), 0)); - return Builder.CreateAlignedLoad(BCastInst, Alignment); + return Builder.CreateAlignedLoad(II.getType(), BCastInst, Alignment); } // Returns true iff the 2 intrinsics have the same operands, limiting the @@ -1827,8 +1766,18 @@ static Instruction *canonicalizeConstantArg0ToArg1(CallInst &Call) { return nullptr; } +Instruction *InstCombiner::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) { + WithOverflowInst *WO = cast<WithOverflowInst>(II); + Value *OperationResult = nullptr; + Constant *OverflowResult = nullptr; + if (OptimizeOverflowCheck(WO->getBinaryOp(), WO->isSigned(), WO->getLHS(), + WO->getRHS(), *WO, OperationResult, OverflowResult)) + return CreateOverflowTuple(WO, OperationResult, OverflowResult); + return nullptr; +} + /// CallInst simplification. This mostly only handles folding of intrinsic -/// instructions. For normal calls, it allows visitCallSite to do the heavy +/// instructions. For normal calls, it allows visitCallBase to do the heavy /// lifting. Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Value *V = SimplifyCall(&CI, SQ.getWithInstruction(&CI))) @@ -1845,10 +1794,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } IntrinsicInst *II = dyn_cast<IntrinsicInst>(&CI); - if (!II) return visitCallSite(&CI); + if (!II) return visitCallBase(CI); - // Intrinsics cannot occur in an invoke, so handle them here instead of in - // visitCallSite. + // Intrinsics cannot occur in an invoke or a callbr, so handle them here + // instead of in visitCallBase. if (auto *MI = dyn_cast<AnyMemIntrinsic>(II)) { bool Changed = false; @@ -1908,6 +1857,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Changed) return II; } + // For vector result intrinsics, use the generic demanded vector support. 
+ if (II->getType()->isVectorTy()) { + auto VWidth = II->getType()->getVectorNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { + if (V != II) + return replaceInstUsesWith(*II, V); + return II; + } + } + if (Instruction *I = SimplifyNVVMIntrinsic(II, *this)) return I; @@ -1918,12 +1879,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts); }; - switch (II->getIntrinsicID()) { + Intrinsic::ID IID = II->getIntrinsicID(); + switch (IID) { default: break; case Intrinsic::objectsize: - if (ConstantInt *N = - lowerObjectSizeCall(II, DL, &TLI, /*MustSucceed=*/false)) - return replaceInstUsesWith(CI, N); + if (Value *V = lowerObjectSizeCall(II, DL, &TLI, /*MustSucceed=*/false)) + return replaceInstUsesWith(CI, V); return nullptr; case Intrinsic::bswap: { Value *IIOperand = II->getArgOperand(0); @@ -1940,15 +1901,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } case Intrinsic::masked_load: - if (Value *SimplifiedMaskedOp = simplifyMaskedLoad(*II, Builder)) + if (Value *SimplifiedMaskedOp = simplifyMaskedLoad(*II)) return replaceInstUsesWith(CI, SimplifiedMaskedOp); break; case Intrinsic::masked_store: - return simplifyMaskedStore(*II, *this); + return simplifyMaskedStore(*II); case Intrinsic::masked_gather: - return simplifyMaskedGather(*II, *this); + return simplifyMaskedGather(*II); case Intrinsic::masked_scatter: - return simplifyMaskedScatter(*II, *this); + return simplifyMaskedScatter(*II); case Intrinsic::launder_invariant_group: case Intrinsic::strip_invariant_group: if (auto *SkippedBarrier = simplifyInvariantGroupIntrinsic(*II, *this)) @@ -1982,33 +1943,62 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::fshl: case Intrinsic::fshr: { - const APInt *SA; - if (match(II->getArgOperand(2), m_APInt(SA))) { - Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1); - unsigned BitWidth = SA->getBitWidth(); - uint64_t ShiftAmt = SA->urem(BitWidth); - assert(ShiftAmt != 0 && "SimplifyCall should have handled zero shift"); - // Normalize to funnel shift left. - if (II->getIntrinsicID() == Intrinsic::fshr) - ShiftAmt = BitWidth - ShiftAmt; - - // fshl(X, 0, C) -> shl X, C - // fshl(X, undef, C) -> shl X, C - if (match(Op1, m_Zero()) || match(Op1, m_Undef())) - return BinaryOperator::CreateShl( - Op0, ConstantInt::get(II->getType(), ShiftAmt)); - - // fshl(0, X, C) -> lshr X, (BW-C) - // fshl(undef, X, C) -> lshr X, (BW-C) - if (match(Op0, m_Zero()) || match(Op0, m_Undef())) - return BinaryOperator::CreateLShr( - Op1, ConstantInt::get(II->getType(), BitWidth - ShiftAmt)); + Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1); + Type *Ty = II->getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + Constant *ShAmtC; + if (match(II->getArgOperand(2), m_Constant(ShAmtC)) && + !isa<ConstantExpr>(ShAmtC) && !ShAmtC->containsConstantExpression()) { + // Canonicalize a shift amount constant operand to modulo the bit-width. 
+ Constant *WidthC = ConstantInt::get(Ty, BitWidth); + Constant *ModuloC = ConstantExpr::getURem(ShAmtC, WidthC); + if (ModuloC != ShAmtC) { + II->setArgOperand(2, ModuloC); + return II; + } + assert(ConstantExpr::getICmp(ICmpInst::ICMP_UGT, WidthC, ShAmtC) == + ConstantInt::getTrue(CmpInst::makeCmpResultType(Ty)) && + "Shift amount expected to be modulo bitwidth"); + + // Canonicalize funnel shift right by constant to funnel shift left. This + // is not entirely arbitrary. For historical reasons, the backend may + // recognize rotate left patterns but miss rotate right patterns. + if (IID == Intrinsic::fshr) { + // fshr X, Y, C --> fshl X, Y, (BitWidth - C) + Constant *LeftShiftC = ConstantExpr::getSub(WidthC, ShAmtC); + Module *Mod = II->getModule(); + Function *Fshl = Intrinsic::getDeclaration(Mod, Intrinsic::fshl, Ty); + return CallInst::Create(Fshl, { Op0, Op1, LeftShiftC }); + } + assert(IID == Intrinsic::fshl && + "All funnel shifts by simple constants should go left"); + + // fshl(X, 0, C) --> shl X, C + // fshl(X, undef, C) --> shl X, C + if (match(Op1, m_ZeroInt()) || match(Op1, m_Undef())) + return BinaryOperator::CreateShl(Op0, ShAmtC); + + // fshl(0, X, C) --> lshr X, (BW-C) + // fshl(undef, X, C) --> lshr X, (BW-C) + if (match(Op0, m_ZeroInt()) || match(Op0, m_Undef())) + return BinaryOperator::CreateLShr(Op1, + ConstantExpr::getSub(WidthC, ShAmtC)); + + // fshl i16 X, X, 8 --> bswap i16 X (reduce to more-specific form) + if (Op0 == Op1 && BitWidth == 16 && match(ShAmtC, m_SpecificInt(8))) { + Module *Mod = II->getModule(); + Function *Bswap = Intrinsic::getDeclaration(Mod, Intrinsic::bswap, Ty); + return CallInst::Create(Bswap, { Op0 }); + } } + // Left or right might be masked. + if (SimplifyDemandedInstructionBits(*II)) + return &CI; + // The shift amount (operand 2) of a funnel shift is modulo the bitwidth, // so only the low bits of the shift amount are demanded if the bitwidth is // a power-of-2. - unsigned BitWidth = II->getType()->getScalarSizeInBits(); if (!isPowerOf2_32(BitWidth)) break; APInt Op2Demanded = APInt::getLowBitsSet(BitWidth, Log2_32_Ceil(BitWidth)); @@ -2018,7 +2008,34 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } case Intrinsic::uadd_with_overflow: - case Intrinsic::sadd_with_overflow: + case Intrinsic::sadd_with_overflow: { + if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) + return I; + if (Instruction *I = foldIntrinsicWithOverflowCommon(II)) + return I; + + // Given 2 constant operands whose sum does not overflow: + // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1 + // saddo (X +nsw C0), C1 -> saddo X, C0 + C1 + Value *X; + const APInt *C0, *C1; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + bool IsSigned = IID == Intrinsic::sadd_with_overflow; + bool HasNWAdd = IsSigned ? match(Arg0, m_NSWAdd(m_Value(X), m_APInt(C0))) + : match(Arg0, m_NUWAdd(m_Value(X), m_APInt(C0))); + if (HasNWAdd && match(Arg1, m_APInt(C1))) { + bool Overflow; + APInt NewC = + IsSigned ? 
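The funnel-shift canonicalizations above (fshr-by-constant to fshl, plus the i16 fshl X, X, 8 to bswap special case) can be checked against a reference 16-bit funnel shift. A standalone sketch (not part of this commit; Fshl/Fshr are local helpers, std::rotl is C++20, and the zero-shift case is left to the earlier simplification as the diff notes):

// Standalone sketch: reference 16-bit funnel shifts.
//   fshr X, Y, C == fshl X, Y, (16 - C)   for C in [1, 15]
//   fshl X, X, C == rotl X, C
//   fshl X, X, 8 == byte swap of X
#include <bit>
#include <cassert>
#include <cstdint>

static uint16_t Fshl(uint16_t X, uint16_t Y, unsigned C) {
  C %= 16; // shift amount is modulo the bit width
  if (C == 0) return X;
  return static_cast<uint16_t>((X << C) | (Y >> (16 - C)));
}
static uint16_t Fshr(uint16_t X, uint16_t Y, unsigned C) {
  C %= 16;
  if (C == 0) return Y;
  return static_cast<uint16_t>((X << (16 - C)) | (Y >> C));
}

int main() {
  const uint16_t Samples[] = {0x0000, 0x0001, 0x8000, 0x1234, 0xBEEF, 0xFFFF};
  for (uint16_t X : Samples) {
    for (uint16_t Y : Samples)
      for (unsigned C = 1; C < 16; ++C)
        assert(Fshr(X, Y, C) == Fshl(X, Y, 16 - C));
    for (unsigned C = 0; C < 16; ++C)
      assert(Fshl(X, X, C) == std::rotl(X, static_cast<int>(C)));
    assert(Fshl(X, X, 8) == static_cast<uint16_t>((X << 8) | (X >> 8))); // bswap
  }
  return 0;
}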
C1->sadd_ov(*C0, Overflow) : C1->uadd_ov(*C0, Overflow); + if (!Overflow) + return replaceInstUsesWith( + *II, Builder.CreateBinaryIntrinsic( + IID, X, ConstantInt::get(Arg1->getType(), NewC))); + } + break; + } + case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) @@ -2026,16 +2043,29 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { LLVM_FALLTHROUGH; case Intrinsic::usub_with_overflow: + if (Instruction *I = foldIntrinsicWithOverflowCommon(II)) + return I; + break; + case Intrinsic::ssub_with_overflow: { - OverflowCheckFlavor OCF = - IntrinsicIDToOverflowCheckFlavor(II->getIntrinsicID()); - assert(OCF != OCF_INVALID && "unexpected!"); + if (Instruction *I = foldIntrinsicWithOverflowCommon(II)) + return I; - Value *OperationResult = nullptr; - Constant *OverflowResult = nullptr; - if (OptimizeOverflowCheck(OCF, II->getArgOperand(0), II->getArgOperand(1), - *II, OperationResult, OverflowResult)) - return CreateOverflowTuple(II, OperationResult, OverflowResult); + Constant *C; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + // Given a constant C that is not the minimum signed value + // for an integer of a given bit width: + // + // ssubo X, C -> saddo X, -C + if (match(Arg1, m_Constant(C)) && C->isNotMinSignedValue()) { + Value *NegVal = ConstantExpr::getNeg(C); + // Build a saddo call that is equivalent to the discovered + // ssubo call. + return replaceInstUsesWith( + *II, Builder.CreateBinaryIntrinsic(Intrinsic::sadd_with_overflow, + Arg0, NegVal)); + } break; } @@ -2047,39 +2077,32 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { LLVM_FALLTHROUGH; case Intrinsic::usub_sat: case Intrinsic::ssub_sat: { - Value *Arg0 = II->getArgOperand(0); - Value *Arg1 = II->getArgOperand(1); - Intrinsic::ID IID = II->getIntrinsicID(); + SaturatingInst *SI = cast<SaturatingInst>(II); + Type *Ty = SI->getType(); + Value *Arg0 = SI->getLHS(); + Value *Arg1 = SI->getRHS(); // Make use of known overflow information. 
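The with.overflow rewrites above (folding a no-wrap-added constant into the intrinsic when the constants' sum does not overflow, and turning ssubo X, C into saddo X, -C for C != MIN) also hold on concrete values. A standalone spot-check (not part of this commit), using the GCC/Clang overflow builtins in place of the llvm.*.with.overflow intrinsics:

// Standalone sketch of the two rewrites on 32-bit values.
#include <cassert>
#include <cstdint>

int main() {
  // uaddo (X +nuw C0), C1 --> uaddo X, (C0 + C1), given C0 + C1 does not wrap
  // and X + C0 does not wrap (true for these samples).
  const uint32_t C0 = 1000, C1 = 4000000000u;
  const uint32_t XsU[] = {0u, 7u, 294967295u, UINT32_MAX - 4000001000u};
  for (uint32_t X : XsU) {
    uint32_t Inner = X + C0;
    uint32_t R1, R2;
    bool Ov1 = __builtin_add_overflow(Inner, C1, &R1);
    bool Ov2 = __builtin_add_overflow(X, C0 + C1, &R2);
    assert(R1 == R2 && Ov1 == Ov2);
  }
  // ssubo X, C --> saddo X, -C, provided C is not the minimum signed value.
  const int32_t C = 12345;
  const int32_t XsS[] = {0, -1, INT32_MAX, INT32_MIN, INT32_MIN + C};
  for (int32_t X : XsS) {
    int32_t R1, R2;
    bool Ov1 = __builtin_sub_overflow(X, C, &R1);
    bool Ov2 = __builtin_add_overflow(X, -C, &R2);
    assert(R1 == R2 && Ov1 == Ov2);
  }
  return 0;
}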
- OverflowResult OR; - switch (IID) { - default: - llvm_unreachable("Unexpected intrinsic!"); - case Intrinsic::uadd_sat: - OR = computeOverflowForUnsignedAdd(Arg0, Arg1, II); - if (OR == OverflowResult::NeverOverflows) - return BinaryOperator::CreateNUWAdd(Arg0, Arg1); - if (OR == OverflowResult::AlwaysOverflows) - return replaceInstUsesWith(*II, - ConstantInt::getAllOnesValue(II->getType())); - break; - case Intrinsic::usub_sat: - OR = computeOverflowForUnsignedSub(Arg0, Arg1, II); - if (OR == OverflowResult::NeverOverflows) - return BinaryOperator::CreateNUWSub(Arg0, Arg1); - if (OR == OverflowResult::AlwaysOverflows) - return replaceInstUsesWith(*II, - ConstantInt::getNullValue(II->getType())); - break; - case Intrinsic::sadd_sat: - if (willNotOverflowSignedAdd(Arg0, Arg1, *II)) - return BinaryOperator::CreateNSWAdd(Arg0, Arg1); - break; - case Intrinsic::ssub_sat: - if (willNotOverflowSignedSub(Arg0, Arg1, *II)) - return BinaryOperator::CreateNSWSub(Arg0, Arg1); - break; + OverflowResult OR = computeOverflow(SI->getBinaryOp(), SI->isSigned(), + Arg0, Arg1, SI); + switch (OR) { + case OverflowResult::MayOverflow: + break; + case OverflowResult::NeverOverflows: + if (SI->isSigned()) + return BinaryOperator::CreateNSW(SI->getBinaryOp(), Arg0, Arg1); + else + return BinaryOperator::CreateNUW(SI->getBinaryOp(), Arg0, Arg1); + case OverflowResult::AlwaysOverflowsLow: { + unsigned BitWidth = Ty->getScalarSizeInBits(); + APInt Min = APSInt::getMinValue(BitWidth, !SI->isSigned()); + return replaceInstUsesWith(*SI, ConstantInt::get(Ty, Min)); + } + case OverflowResult::AlwaysOverflowsHigh: { + unsigned BitWidth = Ty->getScalarSizeInBits(); + APInt Max = APSInt::getMaxValue(BitWidth, !SI->isSigned()); + return replaceInstUsesWith(*SI, ConstantInt::get(Ty, Max)); + } } // ssub.sat(X, C) -> sadd.sat(X, -C) if C != MIN @@ -2101,7 +2124,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { APInt NewVal; bool IsUnsigned = IID == Intrinsic::uadd_sat || IID == Intrinsic::usub_sat; - if (Other->getIntrinsicID() == II->getIntrinsicID() && + if (Other->getIntrinsicID() == IID && match(Arg1, m_APInt(Val)) && match(Other->getArgOperand(0), m_Value(X)) && match(Other->getArgOperand(1), m_APInt(Val2))) { @@ -2136,7 +2159,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return I; Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); - Intrinsic::ID IID = II->getIntrinsicID(); Value *X, *Y; if (match(Arg0, m_FNeg(m_Value(X))) && match(Arg1, m_FNeg(m_Value(Y))) && (Arg0->hasOneUse() || Arg1->hasOneUse())) { @@ -2266,8 +2288,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *ExtSrc; if (match(II->getArgOperand(0), m_OneUse(m_FPExt(m_Value(ExtSrc))))) { // Narrow the call: intrinsic (fpext x) -> fpext (intrinsic x) - Value *NarrowII = - Builder.CreateUnaryIntrinsic(II->getIntrinsicID(), ExtSrc, II); + Value *NarrowII = Builder.CreateUnaryIntrinsic(IID, ExtSrc, II); return new FPExtInst(NarrowII, II->getType()); } break; @@ -2302,7 +2323,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { &DT) >= 16) { Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); - return new LoadInst(Ptr); + return new LoadInst(II->getType(), Ptr); } break; case Intrinsic::ppc_vsx_lxvw4x: @@ -2310,7 +2331,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Turn PPC VSX loads into normal loads. 
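The ssub.sat(X, C) -> sadd.sat(X, -C) rewrite noted above depends only on C not being the minimum signed value (so -C is representable). A standalone, exhaustive 8-bit check (not part of this commit; SAddSat/SSubSat are local reference implementations):

// Standalone sketch: saturating subtract of a constant equals saturating add
// of the negated constant, for every C other than the minimum signed value.
#include <algorithm>
#include <cassert>
#include <cstdint>

static int8_t SAddSat(int8_t A, int8_t B) {
  int Sum = int(A) + int(B);
  return static_cast<int8_t>(std::clamp(Sum, -128, 127));
}
static int8_t SSubSat(int8_t A, int8_t B) {
  int Diff = int(A) - int(B);
  return static_cast<int8_t>(std::clamp(Diff, -128, 127));
}

int main() {
  for (int X = -128; X <= 127; ++X)
    for (int C = -127; C <= 127; ++C) // C == -128 (MIN) is excluded
      assert(SSubSat(static_cast<int8_t>(X), static_cast<int8_t>(C)) ==
             SAddSat(static_cast<int8_t>(X), static_cast<int8_t>(-C)));
  return 0;
}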
Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); - return new LoadInst(Ptr, Twine(""), false, 1); + return new LoadInst(II->getType(), Ptr, Twine(""), false, 1); } case Intrinsic::ppc_altivec_stvx: case Intrinsic::ppc_altivec_stvxl: @@ -2338,7 +2359,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->getType()->getVectorNumElements()); Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(VTy)); - Value *Load = Builder.CreateLoad(Ptr); + Value *Load = Builder.CreateLoad(VTy, Ptr); return new FPExtInst(Load, II->getType()); } break; @@ -2348,7 +2369,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { &DT) >= 32) { Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); - return new LoadInst(Ptr); + return new LoadInst(II->getType(), Ptr); } break; case Intrinsic::ppc_qpx_qvstfs: @@ -2499,22 +2520,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::x86_sse41_round_ps: - case Intrinsic::x86_sse41_round_pd: - case Intrinsic::x86_avx_round_ps_256: - case Intrinsic::x86_avx_round_pd_256: - case Intrinsic::x86_avx512_mask_rndscale_ps_128: - case Intrinsic::x86_avx512_mask_rndscale_ps_256: - case Intrinsic::x86_avx512_mask_rndscale_ps_512: - case Intrinsic::x86_avx512_mask_rndscale_pd_128: - case Intrinsic::x86_avx512_mask_rndscale_pd_256: - case Intrinsic::x86_avx512_mask_rndscale_pd_512: - case Intrinsic::x86_avx512_mask_rndscale_ss: - case Intrinsic::x86_avx512_mask_rndscale_sd: - if (Value *V = simplifyX86round(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - case Intrinsic::x86_mmx_pmovmskb: case Intrinsic::x86_sse_movmsk_ps: case Intrinsic::x86_sse2_movmsk_pd: @@ -2620,7 +2625,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Arg1 = II->getArgOperand(1); Value *V; - switch (II->getIntrinsicID()) { + switch (IID) { default: llvm_unreachable("Case stmts out of sync!"); case Intrinsic::x86_avx512_add_ps_512: case Intrinsic::x86_avx512_add_pd_512: @@ -2664,7 +2669,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *RHS = Builder.CreateExtractElement(Arg1, (uint64_t)0); Value *V; - switch (II->getIntrinsicID()) { + switch (IID) { default: llvm_unreachable("Case stmts out of sync!"); case Intrinsic::x86_avx512_mask_add_ss_round: case Intrinsic::x86_avx512_mask_add_sd_round: @@ -2706,44 +2711,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, V); } } - LLVM_FALLTHROUGH; - - // X86 scalar intrinsics simplified with SimplifyDemandedVectorElts. 
- case Intrinsic::x86_avx512_mask_max_ss_round: - case Intrinsic::x86_avx512_mask_min_ss_round: - case Intrinsic::x86_avx512_mask_max_sd_round: - case Intrinsic::x86_avx512_mask_min_sd_round: - case Intrinsic::x86_sse_cmp_ss: - case Intrinsic::x86_sse_min_ss: - case Intrinsic::x86_sse_max_ss: - case Intrinsic::x86_sse2_cmp_sd: - case Intrinsic::x86_sse2_min_sd: - case Intrinsic::x86_sse2_max_sd: - case Intrinsic::x86_xop_vfrcz_ss: - case Intrinsic::x86_xop_vfrcz_sd: { - unsigned VWidth = II->getType()->getVectorNumElements(); - APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); - if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { - if (V != II) - return replaceInstUsesWith(*II, V); - return II; - } - break; - } - case Intrinsic::x86_sse41_round_ss: - case Intrinsic::x86_sse41_round_sd: { - unsigned VWidth = II->getType()->getVectorNumElements(); - APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); - if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { - if (V != II) - return replaceInstUsesWith(*II, V); - return II; - } else if (Value *V = simplifyX86round(*II, Builder)) - return replaceInstUsesWith(*II, V); break; - } // Constant fold ashr( <A x Bi>, Ci ). // Constant fold lshr( <A x Bi>, Ci ). @@ -2860,7 +2828,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_packsswb: case Intrinsic::x86_avx512_packssdw_512: case Intrinsic::x86_avx512_packsswb_512: - if (Value *V = simplifyX86pack(*II, true)) + if (Value *V = simplifyX86pack(*II, Builder, true)) return replaceInstUsesWith(*II, V); break; @@ -2870,7 +2838,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_packuswb: case Intrinsic::x86_avx512_packusdw_512: case Intrinsic::x86_avx512_packuswb_512: - if (Value *V = simplifyX86pack(*II, false)) + if (Value *V = simplifyX86pack(*II, Builder, false)) return replaceInstUsesWith(*II, V); break; @@ -3168,19 +3136,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return nullptr; break; - case Intrinsic::x86_xop_vpcomb: - case Intrinsic::x86_xop_vpcomd: - case Intrinsic::x86_xop_vpcomq: - case Intrinsic::x86_xop_vpcomw: - if (Value *V = simplifyX86vpcom(*II, Builder, true)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_xop_vpcomub: - case Intrinsic::x86_xop_vpcomud: - case Intrinsic::x86_xop_vpcomuq: - case Intrinsic::x86_xop_vpcomuw: - if (Value *V = simplifyX86vpcom(*II, Builder, false)) + case Intrinsic::x86_addcarry_32: + case Intrinsic::x86_addcarry_64: + if (Value *V = simplifyX86addcarry(*II, Builder)) return replaceInstUsesWith(*II, V); break; @@ -3296,8 +3254,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } // Check for constant LHS & RHS - in this case we just simplify. 
- bool Zext = (II->getIntrinsicID() == Intrinsic::arm_neon_vmullu || - II->getIntrinsicID() == Intrinsic::aarch64_neon_umull); + bool Zext = (IID == Intrinsic::arm_neon_vmullu || + IID == Intrinsic::aarch64_neon_umull); VectorType *NewVT = cast<VectorType>(II->getType()); if (Constant *CV0 = dyn_cast<Constant>(Arg0)) { if (Constant *CV1 = dyn_cast<Constant>(Arg1)) { @@ -3374,7 +3332,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { APFloat Significand = frexp(C->getValueAPF(), Exp, APFloat::rmNearestTiesToEven); - if (II->getIntrinsicID() == Intrinsic::amdgcn_frexp_mant) { + if (IID == Intrinsic::amdgcn_frexp_mant) { return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(), Significand)); } @@ -3559,7 +3517,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } } - bool Signed = II->getIntrinsicID() == Intrinsic::amdgcn_sbfe; + bool Signed = IID == Intrinsic::amdgcn_sbfe; if (!CWidth || !COffset) break; @@ -3587,15 +3545,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } case Intrinsic::amdgcn_exp: case Intrinsic::amdgcn_exp_compr: { - ConstantInt *En = dyn_cast<ConstantInt>(II->getArgOperand(1)); - if (!En) // Illegal. - break; - + ConstantInt *En = cast<ConstantInt>(II->getArgOperand(1)); unsigned EnBits = En->getZExtValue(); if (EnBits == 0xf) break; // All inputs enabled. - bool IsCompr = II->getIntrinsicID() == Intrinsic::amdgcn_exp_compr; + bool IsCompr = IID == Intrinsic::amdgcn_exp_compr; bool Changed = false; for (int I = 0; I < (IsCompr ? 2 : 4); ++I) { if ((!IsCompr && (EnBits & (1 << I)) == 0) || @@ -3680,13 +3635,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } case Intrinsic::amdgcn_icmp: case Intrinsic::amdgcn_fcmp: { - const ConstantInt *CC = dyn_cast<ConstantInt>(II->getArgOperand(2)); - if (!CC) - break; - + const ConstantInt *CC = cast<ConstantInt>(II->getArgOperand(2)); // Guard against invalid arguments. int64_t CCVal = CC->getZExtValue(); - bool IsInteger = II->getIntrinsicID() == Intrinsic::amdgcn_icmp; + bool IsInteger = IID == Intrinsic::amdgcn_icmp; if ((IsInteger && (CCVal < CmpInst::FIRST_ICMP_PREDICATE || CCVal > CmpInst::LAST_ICMP_PREDICATE)) || (!IsInteger && (CCVal < CmpInst::FIRST_FCMP_PREDICATE || @@ -3709,7 +3661,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // register (which contains the bitmask of live threads). So a // comparison that always returns true is the same as a read of the // EXEC register. 
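The amdgcn.frexp.mant hunk above constant-folds the mantissa via APFloat's frexp. As background, a small standard-C++ sketch of the same frexp decomposition; this shows the generic libm semantics, not a claim about the AMDGPU intrinsic's edge cases:

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  // frexp splits x into mant * 2^exp with |mant| in [0.5, 1).
  double x = 24.0;
  int exp = 0;
  double mant = std::frexp(x, &exp);         // mant = 0.75, exp = 5

  assert(mant == 0.75 && exp == 5);
  assert(mant * std::ldexp(1.0, exp) == x);  // reassembles exactly
  std::printf("%f = %f * 2^%d\n", x, mant, exp);
  return 0;
}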
- Value *NewF = Intrinsic::getDeclaration( + Function *NewF = Intrinsic::getDeclaration( II->getModule(), Intrinsic::read_register, II->getType()); Metadata *MDArgs[] = {MDString::get(II->getContext(), "exec")}; MDNode *MD = MDNode::get(II->getContext(), MDArgs); @@ -3804,8 +3756,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } else if (!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy()) break; - Value *NewF = Intrinsic::getDeclaration(II->getModule(), NewIID, - SrcLHS->getType()); + Function *NewF = + Intrinsic::getDeclaration(II->getModule(), NewIID, + { II->getType(), + SrcLHS->getType() }); Value *Args[] = { SrcLHS, SrcRHS, ConstantInt::get(CC->getType(), SrcPred) }; CallInst *NewCall = Builder.CreateCall(NewF, Args); @@ -3833,11 +3787,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::amdgcn_update_dpp: { Value *Old = II->getArgOperand(0); - auto BC = dyn_cast<ConstantInt>(II->getArgOperand(5)); - auto RM = dyn_cast<ConstantInt>(II->getArgOperand(3)); - auto BM = dyn_cast<ConstantInt>(II->getArgOperand(4)); - if (!BC || !RM || !BM || - BC->isZeroValue() || + auto BC = cast<ConstantInt>(II->getArgOperand(5)); + auto RM = cast<ConstantInt>(II->getArgOperand(3)); + auto BM = cast<ConstantInt>(II->getArgOperand(4)); + if (BC->isZeroValue() || RM->getZExtValue() != 0xF || BM->getZExtValue() != 0xF || isa<UndefValue>(Old)) @@ -3847,6 +3800,37 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->setOperand(0, UndefValue::get(Old->getType())); return II; } + case Intrinsic::amdgcn_readfirstlane: + case Intrinsic::amdgcn_readlane: { + // A constant value is trivially uniform. + if (Constant *C = dyn_cast<Constant>(II->getArgOperand(0))) + return replaceInstUsesWith(*II, C); + + // The rest of these may not be safe if the exec may not be the same between + // the def and use. + Value *Src = II->getArgOperand(0); + Instruction *SrcInst = dyn_cast<Instruction>(Src); + if (SrcInst && SrcInst->getParent() != II->getParent()) + break; + + // readfirstlane (readfirstlane x) -> readfirstlane x + // readlane (readfirstlane x), y -> readfirstlane x + if (match(Src, m_Intrinsic<Intrinsic::amdgcn_readfirstlane>())) + return replaceInstUsesWith(*II, Src); + + if (IID == Intrinsic::amdgcn_readfirstlane) { + // readfirstlane (readlane x, y) -> readlane x, y + if (match(Src, m_Intrinsic<Intrinsic::amdgcn_readlane>())) + return replaceInstUsesWith(*II, Src); + } else { + // readlane (readlane x, y), y -> readlane x, y + if (match(Src, m_Intrinsic<Intrinsic::amdgcn_readlane>( + m_Value(), m_Specific(II->getArgOperand(1))))) + return replaceInstUsesWith(*II, Src); + } + + break; + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. @@ -3870,14 +3854,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } if (CallInst *BCI = dyn_cast<CallInst>(BI)) { - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(BCI)) { + if (auto *II2 = dyn_cast<IntrinsicInst>(BCI)) { // If there is a stackrestore below this one, remove this one. - if (II->getIntrinsicID() == Intrinsic::stackrestore) + if (II2->getIntrinsicID() == Intrinsic::stackrestore) return eraseInstFromFunction(CI); // Bail if we cross over an intrinsic with side effects, such as // llvm.stacksave, llvm.read_register, or llvm.setjmp. 
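The new readfirstlane/readlane combines treat the result of a lane read as already uniform, so nested reads collapse. A toy model, assuming a wave is just an array of per-lane values; it deliberately ignores the exec-mask subtlety that the same-basic-block check in the hunk guards against:

#include <array>
#include <cassert>

using Wave = std::array<int, 8>;   // one value per lane in a wave of 8 lanes

// Broadcast of lane 0's value to every lane: the result is uniform.
Wave readfirstlane(const Wave &W) { Wave R; R.fill(W[0]); return R; }

// Broadcast of lane L's value to every lane.
Wave readlane(const Wave &W, unsigned L) { Wave R; R.fill(W[L]); return R; }

int main() {
  Wave W{3, 1, 4, 1, 5, 9, 2, 6};

  // readfirstlane (readfirstlane x) -> readfirstlane x
  assert(readfirstlane(readfirstlane(W)) == readfirstlane(W));
  // readlane (readfirstlane x), y -> readfirstlane x
  assert(readlane(readfirstlane(W), 5) == readfirstlane(W));
  // readlane (readlane x, y), y -> readlane x, y
  assert(readlane(readlane(W, 5), 5) == readlane(W, 5));
  return 0;
}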
- if (II->mayHaveSideEffects()) { + if (II2->mayHaveSideEffects()) { CannotRemove = true; break; } @@ -3920,16 +3904,20 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Canonicalize assume(a && b) -> assume(a); assume(b); // Note: New assumption intrinsics created here are registered by // the InstCombineIRInserter object. - Value *AssumeIntrinsic = II->getCalledValue(), *A, *B; + FunctionType *AssumeIntrinsicTy = II->getFunctionType(); + Value *AssumeIntrinsic = II->getCalledValue(); + Value *A, *B; if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) { - Builder.CreateCall(AssumeIntrinsic, A, II->getName()); - Builder.CreateCall(AssumeIntrinsic, B, II->getName()); + Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, A, II->getName()); + Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, B, II->getName()); return eraseInstFromFunction(*II); } // assume(!(a || b)) -> assume(!a); assume(!b); if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) { - Builder.CreateCall(AssumeIntrinsic, Builder.CreateNot(A), II->getName()); - Builder.CreateCall(AssumeIntrinsic, Builder.CreateNot(B), II->getName()); + Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, + Builder.CreateNot(A), II->getName()); + Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, + Builder.CreateNot(B), II->getName()); return eraseInstFromFunction(*II); } @@ -4036,7 +4024,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } } - return visitCallSite(II); + return visitCallBase(*II); } // Fence instruction simplification @@ -4051,12 +4039,17 @@ Instruction *InstCombiner::visitFenceInst(FenceInst &FI) { // InvokeInst simplification Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { - return visitCallSite(&II); + return visitCallBase(II); +} + +// CallBrInst simplification +Instruction *InstCombiner::visitCallBrInst(CallBrInst &CBI) { + return visitCallBase(CBI); } /// If this cast does not affect the value passed through the varargs area, we /// can eliminate the use of the cast. -static bool isSafeToEliminateVarargsCast(const CallSite CS, +static bool isSafeToEliminateVarargsCast(const CallBase &Call, const DataLayout &DL, const CastInst *const CI, const int ix) { @@ -4068,18 +4061,20 @@ static bool isSafeToEliminateVarargsCast(const CallSite CS, // TODO: This is probably something which should be expanded to all // intrinsics since the entire point of intrinsics is that // they are understandable by the optimizer. - if (isStatepoint(CS) || isGCRelocate(CS) || isGCResult(CS)) + if (isStatepoint(&Call) || isGCRelocate(&Call) || isGCResult(&Call)) return false; // The size of ByVal or InAlloca arguments is derived from the type, so we // can't change to a type with a different size. If the size were // passed explicitly we could avoid this check. - if (!CS.isByValOrInAllocaArgument(ix)) + if (!Call.isByValOrInAllocaArgument(ix)) return true; Type* SrcTy = cast<PointerType>(CI->getOperand(0)->getType())->getElementType(); - Type* DstTy = cast<PointerType>(CI->getType())->getElementType(); + Type *DstTy = Call.isByValArgument(ix) + ? 
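The assume hunk splits llvm.assume over a conjunction so each fact can feed analyses independently, and now passes the callee's FunctionType explicitly when rebuilding the calls. A source-level sketch of the equivalence, assuming clang's __builtin_assume; other compilers spell the hint differently:

#include <cstdio>

int combined(int x) {
  __builtin_assume(x > 0 && x < 16);   // one assume over a conjunction
  return x & 15;                       // foldable to plain 'x'
}

int split(int x) {
  __builtin_assume(x > 0);             // the same facts, stated separately
  __builtin_assume(x < 16);
  return x & 15;
}

int main() {
  std::printf("%d %d\n", combined(7), split(7));
  return 0;
}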
Call.getParamByValType(ix) + : cast<PointerType>(CI->getType())->getElementType(); if (!SrcTy->isSized() || !DstTy->isSized()) return false; if (DL.getTypeAllocSize(SrcTy) != DL.getTypeAllocSize(DstTy)) @@ -4096,7 +4091,7 @@ Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) { auto InstCombineErase = [this](Instruction *I) { eraseInstFromFunction(*I); }; - LibCallSimplifier Simplifier(DL, &TLI, ORE, InstCombineRAUW, + LibCallSimplifier Simplifier(DL, &TLI, ORE, BFI, PSI, InstCombineRAUW, InstCombineErase); if (Value *With = Simplifier.optimizeCall(CI)) { ++NumSimplified; @@ -4182,10 +4177,10 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) { return nullptr; } -/// Improvements for call and invoke instructions. -Instruction *InstCombiner::visitCallSite(CallSite CS) { - if (isAllocLikeFn(CS.getInstruction(), &TLI)) - return visitAllocSite(*CS.getInstruction()); +/// Improvements for call, callbr and invoke instructions. +Instruction *InstCombiner::visitCallBase(CallBase &Call) { + if (isAllocLikeFn(&Call, &TLI)) + return visitAllocSite(Call); bool Changed = false; @@ -4195,52 +4190,50 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { SmallVector<unsigned, 4> ArgNos; unsigned ArgNo = 0; - for (Value *V : CS.args()) { + for (Value *V : Call.args()) { if (V->getType()->isPointerTy() && - !CS.paramHasAttr(ArgNo, Attribute::NonNull) && - isKnownNonZero(V, DL, 0, &AC, CS.getInstruction(), &DT)) + !Call.paramHasAttr(ArgNo, Attribute::NonNull) && + isKnownNonZero(V, DL, 0, &AC, &Call, &DT)) ArgNos.push_back(ArgNo); ArgNo++; } - assert(ArgNo == CS.arg_size() && "sanity check"); + assert(ArgNo == Call.arg_size() && "sanity check"); if (!ArgNos.empty()) { - AttributeList AS = CS.getAttributes(); - LLVMContext &Ctx = CS.getInstruction()->getContext(); + AttributeList AS = Call.getAttributes(); + LLVMContext &Ctx = Call.getContext(); AS = AS.addParamAttribute(Ctx, ArgNos, Attribute::get(Ctx, Attribute::NonNull)); - CS.setAttributes(AS); + Call.setAttributes(AS); Changed = true; } // If the callee is a pointer to a function, attempt to move any casts to the - // arguments of the call/invoke. - Value *Callee = CS.getCalledValue(); - if (!isa<Function>(Callee) && transformConstExprCastCall(CS)) + // arguments of the call/callbr/invoke. + Value *Callee = Call.getCalledValue(); + if (!isa<Function>(Callee) && transformConstExprCastCall(Call)) return nullptr; if (Function *CalleeF = dyn_cast<Function>(Callee)) { // Remove the convergent attr on calls when the callee is not convergent. - if (CS.isConvergent() && !CalleeF->isConvergent() && + if (Call.isConvergent() && !CalleeF->isConvergent() && !CalleeF->isIntrinsic()) { - LLVM_DEBUG(dbgs() << "Removing convergent attr from instr " - << CS.getInstruction() << "\n"); - CS.setNotConvergent(); - return CS.getInstruction(); + LLVM_DEBUG(dbgs() << "Removing convergent attr from instr " << Call + << "\n"); + Call.setNotConvergent(); + return &Call; } // If the call and callee calling conventions don't match, this call must // be unreachable, as the call is undefined. - if (CalleeF->getCallingConv() != CS.getCallingConv() && + if (CalleeF->getCallingConv() != Call.getCallingConv() && // Only do this for calls to a function with a body. A prototype may // not actually end up matching the implementation's calling conv for a // variety of reasons (e.g. it may be written in assembly). 
!CalleeF->isDeclaration()) { - Instruction *OldCall = CS.getInstruction(); - new StoreInst(ConstantInt::getTrue(Callee->getContext()), - UndefValue::get(Type::getInt1PtrTy(Callee->getContext())), - OldCall); + Instruction *OldCall = &Call; + CreateNonTerminatorUnreachable(OldCall); // If OldCall does not return void then replaceAllUsesWith undef. // This allows ValueHandlers and custom metadata to adjust itself. if (!OldCall->getType()->isVoidTy()) @@ -4248,40 +4241,35 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { if (isa<CallInst>(OldCall)) return eraseInstFromFunction(*OldCall); - // We cannot remove an invoke, because it would change the CFG, just - // change the callee to a null pointer. - cast<InvokeInst>(OldCall)->setCalledFunction( - Constant::getNullValue(CalleeF->getType())); + // We cannot remove an invoke or a callbr, because it would change the + // CFG, just change the callee to a null pointer. + cast<CallBase>(OldCall)->setCalledFunction( + CalleeF->getFunctionType(), + Constant::getNullValue(CalleeF->getType())); return nullptr; } } if ((isa<ConstantPointerNull>(Callee) && - !NullPointerIsDefined(CS.getInstruction()->getFunction())) || + !NullPointerIsDefined(Call.getFunction())) || isa<UndefValue>(Callee)) { - // If CS does not return void then replaceAllUsesWith undef. + // If Call does not return void then replaceAllUsesWith undef. // This allows ValueHandlers and custom metadata to adjust itself. - if (!CS.getInstruction()->getType()->isVoidTy()) - replaceInstUsesWith(*CS.getInstruction(), - UndefValue::get(CS.getInstruction()->getType())); + if (!Call.getType()->isVoidTy()) + replaceInstUsesWith(Call, UndefValue::get(Call.getType())); - if (isa<InvokeInst>(CS.getInstruction())) { - // Can't remove an invoke because we cannot change the CFG. + if (Call.isTerminator()) { + // Can't remove an invoke or callbr because we cannot change the CFG. return nullptr; } - // This instruction is not reachable, just remove it. We insert a store to - // undef so that we know that this code is not reachable, despite the fact - // that we can't modify the CFG here. - new StoreInst(ConstantInt::getTrue(Callee->getContext()), - UndefValue::get(Type::getInt1PtrTy(Callee->getContext())), - CS.getInstruction()); - - return eraseInstFromFunction(*CS.getInstruction()); + // This instruction is not reachable, just remove it. + CreateNonTerminatorUnreachable(&Call); + return eraseInstFromFunction(Call); } if (IntrinsicInst *II = findInitTrampoline(Callee)) - return transformCallThroughTrampoline(CS, II); + return transformCallThroughTrampoline(Call, *II); PointerType *PTy = cast<PointerType>(Callee->getType()); FunctionType *FTy = cast<FunctionType>(PTy->getElementType()); @@ -4289,39 +4277,48 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { int ix = FTy->getNumParams(); // See if we can optimize any arguments passed through the varargs area of // the call. - for (CallSite::arg_iterator I = CS.arg_begin() + FTy->getNumParams(), - E = CS.arg_end(); I != E; ++I, ++ix) { + for (auto I = Call.arg_begin() + FTy->getNumParams(), E = Call.arg_end(); + I != E; ++I, ++ix) { CastInst *CI = dyn_cast<CastInst>(*I); - if (CI && isSafeToEliminateVarargsCast(CS, DL, CI, ix)) { + if (CI && isSafeToEliminateVarargsCast(Call, DL, CI, ix)) { *I = CI->getOperand(0); + + // Update the byval type to match the argument type.
+ if (Call.isByValArgument(ix)) { + Call.removeParamAttr(ix, Attribute::ByVal); + Call.addParamAttr( + ix, Attribute::getWithByValType( + Call.getContext(), + CI->getOperand(0)->getType()->getPointerElementType())); + } Changed = true; } } } - if (isa<InlineAsm>(Callee) && !CS.doesNotThrow()) { + if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) { // Inline asm calls cannot throw - mark them 'nounwind'. - CS.setDoesNotThrow(); + Call.setDoesNotThrow(); Changed = true; } // Try to optimize the call if possible, we require DataLayout for most of // this. None of these calls are seen as possibly dead so go ahead and // delete the instruction now. - if (CallInst *CI = dyn_cast<CallInst>(CS.getInstruction())) { + if (CallInst *CI = dyn_cast<CallInst>(&Call)) { Instruction *I = tryOptimizeCall(CI); // If we changed something return the result, etc. Otherwise let // the fallthrough check. if (I) return eraseInstFromFunction(*I); } - return Changed ? CS.getInstruction() : nullptr; + return Changed ? &Call : nullptr; } /// If the callee is a constexpr cast of a function, attempt to move the cast to -/// the arguments of the call/invoke. -bool InstCombiner::transformConstExprCastCall(CallSite CS) { - auto *Callee = dyn_cast<Function>(CS.getCalledValue()->stripPointerCasts()); +/// the arguments of the call/callbr/invoke. +bool InstCombiner::transformConstExprCastCall(CallBase &Call) { + auto *Callee = dyn_cast<Function>(Call.getCalledValue()->stripPointerCasts()); if (!Callee) return false; @@ -4335,11 +4332,11 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // prototype with the exception of pointee types. The code below doesn't // implement that, so we can't do this transform. // TODO: Do the transform if it only requires adding pointer casts. - if (CS.isMustTailCall()) + if (Call.isMustTailCall()) return false; - Instruction *Caller = CS.getInstruction(); - const AttributeList &CallerPAL = CS.getAttributes(); + Instruction *Caller = &Call; + const AttributeList &CallerPAL = Call.getAttributes(); // Okay, this is a cast from a function to a different type. Unless doing so // would cause a type conversion of one of our arguments, change this call to @@ -4370,20 +4367,24 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { return false; // Attribute not compatible with transformed value. } - // If the callsite is an invoke instruction, and the return value is used by - // a PHI node in a successor, we cannot change the return type of the call - // because there is no place to put the cast instruction (without breaking - // the critical edge). Bail out in this case. - if (!Caller->use_empty()) + // If the callbase is an invoke/callbr instruction, and the return value is + // used by a PHI node in a successor, we cannot change the return type of + // the call because there is no place to put the cast instruction (without + // breaking the critical edge). Bail out in this case. + if (!Caller->use_empty()) { if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) for (User *U : II->users()) if (PHINode *PN = dyn_cast<PHINode>(U)) if (PN->getParent() == II->getNormalDest() || PN->getParent() == II->getUnwindDest()) return false; + // FIXME: Be conservative for callbr to avoid a quadratic search. 
+ if (isa<CallBrInst>(Caller)) + return false; + } } - unsigned NumActualArgs = CS.arg_size(); + unsigned NumActualArgs = Call.arg_size(); unsigned NumCommonArgs = std::min(FT->getNumParams(), NumActualArgs); // Prevent us turning: @@ -4398,7 +4399,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { Callee->getAttributes().hasAttrSomewhere(Attribute::ByVal)) return false; - CallSite::arg_iterator AI = CS.arg_begin(); + auto AI = Call.arg_begin(); for (unsigned i = 0, e = NumCommonArgs; i != e; ++i, ++AI) { Type *ParamTy = FT->getParamType(i); Type *ActTy = (*AI)->getType(); @@ -4410,7 +4411,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { .overlaps(AttributeFuncs::typeIncompatible(ParamTy))) return false; // Attribute not compatible with transformed value. - if (CS.isInAllocaArgument(i)) + if (Call.isInAllocaArgument(i)) return false; // Cannot transform to and from inalloca. // If the parameter is passed as a byval argument, then we have to have a @@ -4420,7 +4421,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { if (!ParamPTy || !ParamPTy->getElementType()->isSized()) return false; - Type *CurElTy = ActTy->getPointerElementType(); + Type *CurElTy = Call.getParamByValType(i); if (DL.getTypeAllocSize(CurElTy) != DL.getTypeAllocSize(ParamPTy->getElementType())) return false; @@ -4435,7 +4436,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // If the callee is just a declaration, don't change the varargsness of the // call. We don't want to introduce a varargs call where one doesn't // already exist. - PointerType *APTy = cast<PointerType>(CS.getCalledValue()->getType()); + PointerType *APTy = cast<PointerType>(Call.getCalledValue()->getType()); if (FT->isVarArg()!=cast<FunctionType>(APTy->getElementType())->isVarArg()) return false; @@ -4474,7 +4475,8 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // with the existing attributes. Wipe out any problematic attributes. RAttrs.remove(AttributeFuncs::typeIncompatible(NewRetTy)); - AI = CS.arg_begin(); + LLVMContext &Ctx = Call.getContext(); + AI = Call.arg_begin(); for (unsigned i = 0; i != NumCommonArgs; ++i, ++AI) { Type *ParamTy = FT->getParamType(i); @@ -4484,7 +4486,12 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { Args.push_back(NewArg); // Add any parameter attributes. 
- ArgAttrs.push_back(CallerPAL.getParamAttributes(i)); + if (CallerPAL.hasParamAttribute(i, Attribute::ByVal)) { + AttrBuilder AB(CallerPAL.getParamAttributes(i)); + AB.addByValAttr(NewArg->getType()->getPointerElementType()); + ArgAttrs.push_back(AttributeSet::get(Ctx, AB)); + } else + ArgAttrs.push_back(CallerPAL.getParamAttributes(i)); } // If the function takes more arguments than the call was taking, add them @@ -4523,45 +4530,50 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { assert((ArgAttrs.size() == FT->getNumParams() || FT->isVarArg()) && "missing argument attributes"); - LLVMContext &Ctx = Callee->getContext(); AttributeList NewCallerPAL = AttributeList::get( Ctx, FnAttrs, AttributeSet::get(Ctx, RAttrs), ArgAttrs); SmallVector<OperandBundleDef, 1> OpBundles; - CS.getOperandBundlesAsDefs(OpBundles); + Call.getOperandBundlesAsDefs(OpBundles); - CallSite NewCS; + CallBase *NewCall; if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { - NewCS = Builder.CreateInvoke(Callee, II->getNormalDest(), - II->getUnwindDest(), Args, OpBundles); + NewCall = Builder.CreateInvoke(Callee, II->getNormalDest(), + II->getUnwindDest(), Args, OpBundles); + } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(Caller)) { + NewCall = Builder.CreateCallBr(Callee, CBI->getDefaultDest(), + CBI->getIndirectDests(), Args, OpBundles); } else { - NewCS = Builder.CreateCall(Callee, Args, OpBundles); - cast<CallInst>(NewCS.getInstruction()) - ->setTailCallKind(cast<CallInst>(Caller)->getTailCallKind()); + NewCall = Builder.CreateCall(Callee, Args, OpBundles); + cast<CallInst>(NewCall)->setTailCallKind( + cast<CallInst>(Caller)->getTailCallKind()); } - NewCS->takeName(Caller); - NewCS.setCallingConv(CS.getCallingConv()); - NewCS.setAttributes(NewCallerPAL); + NewCall->takeName(Caller); + NewCall->setCallingConv(Call.getCallingConv()); + NewCall->setAttributes(NewCallerPAL); // Preserve the weight metadata for the new call instruction. The metadata // is used by SamplePGO to check callsite's hotness. uint64_t W; if (Caller->extractProfTotalWeight(W)) - NewCS->setProfWeight(W); + NewCall->setProfWeight(W); // Insert a cast of the return type as necessary. - Instruction *NC = NewCS.getInstruction(); + Instruction *NC = NewCall; Value *NV = NC; if (OldRetTy != NV->getType() && !Caller->use_empty()) { if (!NV->getType()->isVoidTy()) { NV = NC = CastInst::CreateBitOrPointerCast(NC, OldRetTy); NC->setDebugLoc(Caller->getDebugLoc()); - // If this is an invoke instruction, we should insert it after the first - // non-phi, instruction in the normal successor block. + // If this is an invoke/callbr instruction, we should insert it after the + // first non-phi instruction in the normal successor block. if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { BasicBlock::iterator I = II->getNormalDest()->getFirstInsertionPt(); InsertNewInstBefore(NC, *I); + } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(Caller)) { + BasicBlock::iterator I = CBI->getDefaultDest()->getFirstInsertionPt(); + InsertNewInstBefore(NC, *I); } else { // Otherwise, it's a call, just insert cast right after the call. InsertNewInstBefore(NC, *Caller); @@ -4590,23 +4602,20 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { /// Turn a call to a function created by init_trampoline / adjust_trampoline /// intrinsic pair into a direct call to the underlying function. 
Instruction * -InstCombiner::transformCallThroughTrampoline(CallSite CS, - IntrinsicInst *Tramp) { - Value *Callee = CS.getCalledValue(); - PointerType *PTy = cast<PointerType>(Callee->getType()); - FunctionType *FTy = cast<FunctionType>(PTy->getElementType()); - AttributeList Attrs = CS.getAttributes(); +InstCombiner::transformCallThroughTrampoline(CallBase &Call, + IntrinsicInst &Tramp) { + Value *Callee = Call.getCalledValue(); + Type *CalleeTy = Callee->getType(); + FunctionType *FTy = Call.getFunctionType(); + AttributeList Attrs = Call.getAttributes(); // If the call already has the 'nest' attribute somewhere then give up - // otherwise 'nest' would occur twice after splicing in the chain. if (Attrs.hasAttrSomewhere(Attribute::Nest)) return nullptr; - assert(Tramp && - "transformCallThroughTrampoline called with incorrect CallSite."); - - Function *NestF =cast<Function>(Tramp->getArgOperand(1)->stripPointerCasts()); - FunctionType *NestFTy = cast<FunctionType>(NestF->getValueType()); + Function *NestF = cast<Function>(Tramp.getArgOperand(1)->stripPointerCasts()); + FunctionType *NestFTy = NestF->getFunctionType(); AttributeList NestAttrs = NestF->getAttributes(); if (!NestAttrs.isEmpty()) { @@ -4628,22 +4637,21 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, } if (NestTy) { - Instruction *Caller = CS.getInstruction(); std::vector<Value*> NewArgs; std::vector<AttributeSet> NewArgAttrs; - NewArgs.reserve(CS.arg_size() + 1); - NewArgAttrs.reserve(CS.arg_size()); + NewArgs.reserve(Call.arg_size() + 1); + NewArgAttrs.reserve(Call.arg_size()); // Insert the nest argument into the call argument list, which may // mean appending it. Likewise for attributes. { unsigned ArgNo = 0; - CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); + auto I = Call.arg_begin(), E = Call.arg_end(); do { if (ArgNo == NestArgNo) { // Add the chain argument and attributes. 
- Value *NestVal = Tramp->getArgOperand(2); + Value *NestVal = Tramp.getArgOperand(2); if (NestVal->getType() != NestTy) NestVal = Builder.CreateBitCast(NestVal, NestTy, "nest"); NewArgs.push_back(NestVal); @@ -4705,24 +4713,30 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, Attrs.getRetAttributes(), NewArgAttrs); SmallVector<OperandBundleDef, 1> OpBundles; - CS.getOperandBundlesAsDefs(OpBundles); + Call.getOperandBundlesAsDefs(OpBundles); Instruction *NewCaller; - if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { - NewCaller = InvokeInst::Create(NewCallee, + if (InvokeInst *II = dyn_cast<InvokeInst>(&Call)) { + NewCaller = InvokeInst::Create(NewFTy, NewCallee, II->getNormalDest(), II->getUnwindDest(), NewArgs, OpBundles); cast<InvokeInst>(NewCaller)->setCallingConv(II->getCallingConv()); cast<InvokeInst>(NewCaller)->setAttributes(NewPAL); + } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(&Call)) { + NewCaller = + CallBrInst::Create(NewFTy, NewCallee, CBI->getDefaultDest(), + CBI->getIndirectDests(), NewArgs, OpBundles); + cast<CallBrInst>(NewCaller)->setCallingConv(CBI->getCallingConv()); + cast<CallBrInst>(NewCaller)->setAttributes(NewPAL); } else { - NewCaller = CallInst::Create(NewCallee, NewArgs, OpBundles); + NewCaller = CallInst::Create(NewFTy, NewCallee, NewArgs, OpBundles); cast<CallInst>(NewCaller)->setTailCallKind( - cast<CallInst>(Caller)->getTailCallKind()); + cast<CallInst>(Call).getTailCallKind()); cast<CallInst>(NewCaller)->setCallingConv( - cast<CallInst>(Caller)->getCallingConv()); + cast<CallInst>(Call).getCallingConv()); cast<CallInst>(NewCaller)->setAttributes(NewPAL); } - NewCaller->setDebugLoc(Caller->getDebugLoc()); + NewCaller->setDebugLoc(Call.getDebugLoc()); return NewCaller; } @@ -4731,9 +4745,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, // Replace the trampoline call with a direct call. Since there is no 'nest' // parameter, there is no need to adjust the argument list. Let the generic // code sort out any function type mismatches. - Constant *NewCallee = - NestF->getType() == PTy ? NestF : - ConstantExpr::getBitCast(NestF, PTy); - CS.setCalledFunction(NewCallee); - return CS.getInstruction(); + Constant *NewCallee = ConstantExpr::getBitCast(NestF, CalleeTy); + Call.setCalledFunction(FTy, NewCallee); + return &Call; } diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 1201ac196ec0..2c9ba203fbf3 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1,9 +1,8 @@ //===- InstCombineCasts.cpp -----------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -1373,10 +1372,8 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // If we know that the value being extended is positive, we can use a zext // instead. 
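The visitSExt hunk above now returns a plain zext cast directly when the operand is known non-negative. A quick check of the underlying identity over every non-negative i8 value:

#include <cassert>
#include <cstdint>

int main() {
  for (int v = 0; v <= 127; ++v) {               // all non-negative int8_t values
    int8_t s = static_cast<int8_t>(v);
    int32_t sext = s;                            // sign extension
    int32_t zext = static_cast<uint8_t>(s);      // zero extension
    assert(sext == zext);  // identical whenever the sign bit is clear
  }
  return 0;
}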
KnownBits Known = computeKnownBits(Src, 0, &CI); - if (Known.isNonNegative()) { - Value *ZExt = Builder.CreateZExt(Src, DestTy); - return replaceInstUsesWith(CI, ZExt); - } + if (Known.isNonNegative()) + return CastInst::Create(Instruction::ZExt, Src, DestTy); // Try to extend the entire expression tree to the wide destination type. if (shouldChangeType(SrcTy, DestTy) && canEvaluateSExtd(Src, DestTy)) { @@ -1618,12 +1615,20 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { return CastInst::CreateFPCast(ExactResult, Ty); } } + } - // (fptrunc (fneg x)) -> (fneg (fptrunc x)) - Value *X; - if (match(OpI, m_FNeg(m_Value(X)))) { + // (fptrunc (fneg x)) -> (fneg (fptrunc x)) + Value *X; + Instruction *Op = dyn_cast<Instruction>(FPT.getOperand(0)); + if (Op && Op->hasOneUse()) { + if (match(Op, m_FNeg(m_Value(X)))) { Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty); - return BinaryOperator::CreateFNegFMF(InnerTrunc, OpI); + + // FIXME: Once we're sure that unary FNeg optimizations are on par with + // binary FNeg, this should always return a unary operator. + if (isa<BinaryOperator>(Op)) + return BinaryOperator::CreateFNegFMF(InnerTrunc, Op); + return UnaryOperator::CreateFNegFMF(InnerTrunc, Op); } } @@ -1657,8 +1662,8 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { II->getIntrinsicID(), Ty); SmallVector<OperandBundleDef, 1> OpBundles; II->getOperandBundlesAsDefs(OpBundles); - CallInst *NewCI = CallInst::Create(Overload, { InnerTrunc }, OpBundles, - II->getName()); + CallInst *NewCI = + CallInst::Create(Overload, {InnerTrunc}, OpBundles, II->getName()); NewCI->copyFastMathFlags(II); return NewCI; } @@ -2167,7 +2172,7 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { SmallSetVector<PHINode *, 4> OldPhiNodes; // Find all of the A->B casts and PHI nodes. - // We need to inpect all related PHI nodes, but PHIs can be cyclic, so + // We need to inspect all related PHI nodes, but PHIs can be cyclic, so // OldPhiNodes is used to track all known PHI nodes, before adding a new // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first. PhiWorklist.push_back(PN); @@ -2242,20 +2247,43 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { } } + // Traverse all accumulated PHI nodes and process its users, + // which are Stores and BitcCasts. Without this processing + // NewPHI nodes could be replicated and could lead to extra + // moves generated after DeSSA. // If there is a store with type B, change it to type A. - for (User *U : PN->users()) { - auto *SI = dyn_cast<StoreInst>(U); - if (SI && SI->isSimple() && SI->getOperand(0) == PN) { - Builder.SetInsertPoint(SI); - auto *NewBC = - cast<BitCastInst>(Builder.CreateBitCast(NewPNodes[PN], SrcTy)); - SI->setOperand(0, NewBC); - Worklist.Add(SI); - assert(hasStoreUsersOnly(*NewBC)); + + + // Replace users of BitCast B->A with NewPHI. These will help + // later to get rid off a closure formed by OldPHI nodes. + Instruction *RetVal = nullptr; + for (auto *OldPN : OldPhiNodes) { + PHINode *NewPN = NewPNodes[OldPN]; + for (User *V : OldPN->users()) { + if (auto *SI = dyn_cast<StoreInst>(V)) { + if (SI->isSimple() && SI->getOperand(0) == OldPN) { + Builder.SetInsertPoint(SI); + auto *NewBC = + cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy)); + SI->setOperand(0, NewBC); + Worklist.Add(SI); + assert(hasStoreUsersOnly(*NewBC)); + } + } + else if (auto *BCI = dyn_cast<BitCastInst>(V)) { + // Verify it's a B->A cast. 
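The fptrunc hunk commutes fneg with fptrunc, now with a one-use check and a path for the unary FNeg instruction. The identity holds because negation is exact and default IEEE rounding is sign-symmetric; a spot check over a few finite values:

#include <cassert>
#include <cmath>

int main() {
  const double samples[] = {0.0, -0.0, 1.5, -2.25, 3.14159, 1e30, -1e-30};
  for (double d : samples) {
    float a = static_cast<float>(-d);            // fneg, then fptrunc
    float b = -static_cast<float>(d);            // fptrunc, then fneg
    assert(a == b);                              // same value...
    assert(std::signbit(a) == std::signbit(b));  // ...and same sign, even for zero
  }
  return 0;
}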
+ Type *TyB = BCI->getOperand(0)->getType(); + Type *TyA = BCI->getType(); + if (TyA == DestTy && TyB == SrcTy) { + Instruction *I = replaceInstUsesWith(*BCI, NewPN); + if (BCI == &CI) + RetVal = I; + } + } } } - return replaceInstUsesWith(CI, NewPNodes[PN]); + return RetVal; } Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { @@ -2310,7 +2338,8 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // If we found a path from the src to dest, create the getelementptr now. if (SrcElTy == DstElTy) { SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder.getInt32(0)); - return GetElementPtrInst::CreateInBounds(Src, Idxs); + return GetElementPtrInst::CreateInBounds(SrcPTy->getElementType(), Src, + Idxs); } } @@ -2355,11 +2384,10 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { } // Otherwise, see if our source is an insert. If so, then use the scalar - // component directly. - if (InsertElementInst *IEI = - dyn_cast<InsertElementInst>(CI.getOperand(0))) - return CastInst::Create(Instruction::BitCast, IEI->getOperand(1), - DestTy); + // component directly: + // bitcast (inselt <1 x elt> V, X, 0) to <n x m> --> bitcast X to <n x m> + if (auto *InsElt = dyn_cast<InsertElementInst>(Src)) + return new BitCastInst(InsElt->getOperand(1), DestTy); } } diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index b5bbb09935e2..3a4283ae5406 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1,9 +1,8 @@ //===- InstCombineCompares.cpp --------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -704,7 +703,10 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base, continue; if (auto *CI = dyn_cast<CastInst>(Val)) { - NewInsts[CI] = NewInsts[CI->getOperand(0)]; + // Don't get rid of the intermediate variable here; the store can grow + // the map which will invalidate the reference to the input value. + Value *V = NewInsts[CI->getOperand(0)]; + NewInsts[CI] = V; continue; } if (auto *GEP = dyn_cast<GEPOperator>(Val)) { @@ -1292,8 +1294,8 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // use the sadd_with_overflow intrinsic to efficiently compute both the // result and the overflow bit. 
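The rewriteGEPAsOffset fix above is about reference invalidation: DenseMap's operator[] on the left-hand side may grow the table and invalidate the reference just obtained for the right-hand side, so the looked-up value is copied into a local first. The same hazard, shown with std::vector as a stand-in (std::map would not reproduce it, since its references survive insertion):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> storage = {10, 20, 30};

  // Dangerous shape of the original code:
  //   int &ref = storage[0];
  //   storage.push_back(40);   // may reallocate
  //   storage.back() = ref;    // use-after-free if it did
  //
  // Safe shape of the fix: copy the value into a local before any operation
  // that can grow the container.
  int val = storage[0];
  storage.push_back(40);
  storage.back() = val;

  std::printf("%d\n", storage.back());   // prints 10
  return 0;
}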
Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth); - Value *F = Intrinsic::getDeclaration(I.getModule(), - Intrinsic::sadd_with_overflow, NewType); + Function *F = Intrinsic::getDeclaration( + I.getModule(), Intrinsic::sadd_with_overflow, NewType); InstCombiner::BuilderTy &Builder = IC.Builder; @@ -1315,14 +1317,16 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, return ExtractValueInst::Create(Call, 1, "sadd.overflow"); } -// Handle (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) +// Handle icmp pred X, 0 Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { CmpInst::Predicate Pred = Cmp.getPredicate(); - Value *X = Cmp.getOperand(0); + if (!match(Cmp.getOperand(1), m_Zero())) + return nullptr; - if (match(Cmp.getOperand(1), m_Zero()) && Pred == ICmpInst::ICMP_SGT) { + // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) + if (Pred == ICmpInst::ICMP_SGT) { Value *A, *B; - SelectPatternResult SPR = matchSelectPattern(X, A, B); + SelectPatternResult SPR = matchSelectPattern(Cmp.getOperand(0), A, B); if (SPR.Flavor == SPF_SMIN) { if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) return new ICmpInst(Pred, B, Cmp.getOperand(1)); @@ -1330,6 +1334,20 @@ Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { return new ICmpInst(Pred, A, Cmp.getOperand(1)); } } + + // Given: + // icmp eq/ne (urem %x, %y), 0 + // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem': + // icmp eq/ne %x, 0 + Value *X, *Y; + if (match(Cmp.getOperand(0), m_URem(m_Value(X), m_Value(Y))) && + ICmpInst::isEquality(Pred)) { + KnownBits XKnown = computeKnownBits(X, 0, &Cmp); + KnownBits YKnown = computeKnownBits(Y, 0, &Cmp); + if (XKnown.countMaxPopulation() == 1 && YKnown.countMinPopulation() >= 2) + return new ICmpInst(Pred, X, Cmp.getOperand(1)); + } + return nullptr; } @@ -1624,20 +1642,43 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, const APInt &C1) { + bool isICMP_NE = Cmp.getPredicate() == ICmpInst::ICMP_NE; + // For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1 // TODO: We canonicalize to the longer form for scalars because we have // better analysis/folds for icmp, and codegen may be better with icmp. - if (Cmp.getPredicate() == CmpInst::ICMP_NE && Cmp.getType()->isVectorTy() && - C1.isNullValue() && match(And->getOperand(1), m_One())) + if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isNullValue() && + match(And->getOperand(1), m_One())) return new TruncInst(And->getOperand(0), Cmp.getType()); const APInt *C2; - if (!match(And->getOperand(1), m_APInt(C2))) + Value *X; + if (!match(And, m_And(m_Value(X), m_APInt(C2)))) return nullptr; + // Don't perform the following transforms if the AND has multiple uses if (!And->hasOneUse()) return nullptr; + if (Cmp.isEquality() && C1.isNullValue()) { + // Restrict this fold to single-use 'and' (PR10267). + // Replace (and X, (1 << size(X)-1) != 0) with X s< 0 + if (C2->isSignMask()) { + Constant *Zero = Constant::getNullValue(X->getType()); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; + return new ICmpInst(NewPred, X, Zero); + } + + // Restrict this fold only for single-use 'and' (PR10267). + // ((%x & C) == 0) --> %x u< (-C) iff (-C) is power of two. + if ((~(*C2) + 1).isPowerOf2()) { + Constant *NegBOC = + ConstantExpr::getNeg(cast<Constant>(And->getOperand(1))); + auto NewPred = isICMP_NE ? 
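The new foldICmpWithZero case drops a urem from an equality test against zero when known-bits show x has at most one set bit and y has at least two: no multi-bit y can divide a power of two, so x % y == 0 exactly when x itself is zero. A brute-force check of that claim (std::popcount needs C++20 <bit>):

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x = 0; x < 1u << 12; ++x) {
    if (std::popcount(x) > 1)
      continue;                    // x must have at most one bit set
    for (uint32_t y = 1; y < 1u << 12; ++y) {
      if (std::popcount(y) < 2)
        continue;                  // y must have at least two bits set
      assert(((x % y) == 0) == (x == 0));
    }
  }
  return 0;
}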
ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; + return new ICmpInst(NewPred, X, NegBOC); + } + } + // If the LHS is an 'and' of a truncate and we can widen the and/compare to // the input width without changing the value produced, eliminate the cast: // @@ -1772,13 +1813,22 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, ConstantInt::get(V->getType(), 1)); } - // X | C == C --> X <=u C - // X | C != C --> X >u C - // iff C+1 is a power of 2 (C is a bitmask of the low bits) - if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) && - (C + 1).isPowerOf2()) { - Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; - return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1)); + Value *OrOp0 = Or->getOperand(0), *OrOp1 = Or->getOperand(1); + if (Cmp.isEquality() && Cmp.getOperand(1) == OrOp1) { + // X | C == C --> X <=u C + // X | C != C --> X >u C + // iff C+1 is a power of 2 (C is a bitmask of the low bits) + if ((C + 1).isPowerOf2()) { + Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; + return new ICmpInst(Pred, OrOp0, OrOp1); + } + // More general: are all bits outside of a mask constant set or not set? + // X | C == C --> (X & ~C) == 0 + // X | C != C --> (X & ~C) != 0 + if (Or->hasOneUse()) { + Value *A = Builder.CreateAnd(OrOp0, ~C); + return new ICmpInst(Pred, A, ConstantInt::getNullValue(OrOp0->getType())); + } } if (!Cmp.isEquality() || !C.isNullValue() || !Or->hasOneUse()) @@ -1799,8 +1849,8 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, // Are we using xors to bitwise check for a pair of (in)equalities? Convert to // a shorter form that has more potential to be folded even further. Value *X1, *X2, *X3, *X4; - if (match(Or->getOperand(0), m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) && - match(Or->getOperand(1), m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) { + if (match(OrOp0, m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) && + match(OrOp1, m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) { // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4) // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4) Value *Cmp12 = Builder.CreateICmp(Pred, X1, X2); @@ -1994,6 +2044,27 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, And, Constant::getNullValue(ShType)); } + // Simplify 'shl' inequality test into 'and' equality test. + if (Cmp.isUnsigned() && Shl->hasOneUse()) { + // (X l<< C2) u<=/u> C1 iff C1+1 is power of two -> X & (~C1 l>> C2) ==/!= 0 + if ((C + 1).isPowerOf2() && + (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT)) { + Value *And = Builder.CreateAnd(X, (~C).lshr(ShiftAmt->getZExtValue())); + return new ICmpInst(Pred == ICmpInst::ICMP_ULE ? ICmpInst::ICMP_EQ + : ICmpInst::ICMP_NE, + And, Constant::getNullValue(ShType)); + } + // (X l<< C2) u</u>= C1 iff C1 is power of two -> X & (-C1 l>> C2) ==/!= 0 + if (C.isPowerOf2() && + (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) { + Value *And = + Builder.CreateAnd(X, (~(C - 1)).lshr(ShiftAmt->getZExtValue())); + return new ICmpInst(Pred == ICmpInst::ICMP_ULT ? ICmpInst::ICMP_EQ + : ICmpInst::ICMP_NE, + And, Constant::getNullValue(ShType)); + } + } + // Transform (icmp pred iM (shl iM %v, N), C) // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (C>>N)) // Transform the shl to a trunc if (trunc (C>>N)) has no loss and M-N. 
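The foldICmpOrConstant hunk keeps the unsigned-compare form when C+1 is a power of two and otherwise generalizes X | C == C into a mask test on the bits outside C. A check of the general identity, using a mask for which C+1 is not a power of two:

#include <cassert>
#include <cstdint>

int main() {
  const uint8_t c = 0x2C;                        // C+1 == 0x2D, not a power of two
  for (unsigned v = 0; v <= 0xFF; ++v) {
    uint8_t x = static_cast<uint8_t>(v);
    bool orForm  = static_cast<uint8_t>(x | c) == c;            // X | C == C
    bool andForm = (x & static_cast<uint8_t>(~c)) == 0;         // (X & ~C) == 0
    assert(orForm == andForm);
  }
  return 0;
}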
@@ -2313,6 +2384,16 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, const APInt &C) { Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); ICmpInst::Predicate Pred = Cmp.getPredicate(); + const APInt *C2; + APInt SubResult; + + // (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C) + if (match(X, m_APInt(C2)) && + ((Cmp.isUnsigned() && Sub->hasNoUnsignedWrap()) || + (Cmp.isSigned() && Sub->hasNoSignedWrap())) && + !subWithOverflow(SubResult, *C2, C, Cmp.isSigned())) + return new ICmpInst(Cmp.getSwappedPredicate(), Y, + ConstantInt::get(Y->getType(), SubResult)); // The following transforms are only worth it if the only user of the subtract // is the icmp. @@ -2337,7 +2418,6 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); } - const APInt *C2; if (!match(X, m_APInt(C2))) return nullptr; @@ -2482,20 +2562,76 @@ Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, // the entire original Cmp can be simplified to a false. Value *Cond = Builder.getFalse(); if (TrueWhenLessThan) - Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT, OrigLHS, OrigRHS)); + Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT, + OrigLHS, OrigRHS)); if (TrueWhenEqual) - Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ, OrigLHS, OrigRHS)); + Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ, + OrigLHS, OrigRHS)); if (TrueWhenGreaterThan) - Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT, OrigLHS, OrigRHS)); + Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT, + OrigLHS, OrigRHS)); return replaceInstUsesWith(Cmp, Cond); } return nullptr; } -Instruction *InstCombiner::foldICmpBitCastConstant(ICmpInst &Cmp, - BitCastInst *Bitcast, - const APInt &C) { +static Instruction *foldICmpBitCast(ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + auto *Bitcast = dyn_cast<BitCastInst>(Cmp.getOperand(0)); + if (!Bitcast) + return nullptr; + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Op1 = Cmp.getOperand(1); + Value *BCSrcOp = Bitcast->getOperand(0); + + // Make sure the bitcast doesn't change the number of vector elements. + if (Bitcast->getSrcTy()->getScalarSizeInBits() == + Bitcast->getDestTy()->getScalarSizeInBits()) { + // Zero-equality and sign-bit checks are preserved through sitofp + bitcast. 
+ Value *X; + if (match(BCSrcOp, m_SIToFP(m_Value(X)))) { + // icmp eq (bitcast (sitofp X)), 0 --> icmp eq X, 0 + // icmp ne (bitcast (sitofp X)), 0 --> icmp ne X, 0 + // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0 + // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0 + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT || + Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) && + match(Op1, m_Zero())) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + + // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1 + if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One())) + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1)); + + // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1 + if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes())) + return new ICmpInst(Pred, X, + ConstantInt::getAllOnesValue(X->getType())); + } + + // Zero-equality checks are preserved through unsigned floating-point casts: + // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0 + // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0 + if (match(BCSrcOp, m_UIToFP(m_Value(X)))) + if (Cmp.isEquality() && match(Op1, m_Zero())) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + } + + // Test to see if the operands of the icmp are casted versions of other + // values. If the ptr->ptr cast can be stripped off both arguments, do so. + if (Bitcast->getType()->isPointerTy() && + (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) { + // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast + // so eliminate it as well. + if (auto *BC2 = dyn_cast<BitCastInst>(Op1)) + Op1 = BC2->getOperand(0); + + Op1 = Builder.CreateBitCast(Op1, BCSrcOp->getType()); + return new ICmpInst(Pred, BCSrcOp, Op1); + } + // Folding: icmp <pred> iN X, C // where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC)) to iN // and C is a splat of a K-bit pattern @@ -2503,28 +2639,28 @@ Instruction *InstCombiner::foldICmpBitCastConstant(ICmpInst &Cmp, // Into: // %E = extractelement <M x iK> %vec, i32 C' // icmp <pred> iK %E, trunc(C) - if (!Bitcast->getType()->isIntegerTy() || + const APInt *C; + if (!match(Cmp.getOperand(1), m_APInt(C)) || + !Bitcast->getType()->isIntegerTy() || !Bitcast->getSrcTy()->isIntOrIntVectorTy()) return nullptr; - Value *BCIOp = Bitcast->getOperand(0); - Value *Vec = nullptr; // 1st vector arg of the shufflevector - Constant *Mask = nullptr; // Mask arg of the shufflevector - if (match(BCIOp, + Value *Vec; + Constant *Mask; + if (match(BCSrcOp, m_ShuffleVector(m_Value(Vec), m_Undef(), m_Constant(Mask)))) { // Check whether every element of Mask is the same constant if (auto *Elem = dyn_cast_or_null<ConstantInt>(Mask->getSplatValue())) { - auto *VecTy = cast<VectorType>(BCIOp->getType()); + auto *VecTy = cast<VectorType>(BCSrcOp->getType()); auto *EltTy = cast<IntegerType>(VecTy->getElementType()); - auto Pred = Cmp.getPredicate(); - if (C.isSplat(EltTy->getBitWidth())) { + if (C->isSplat(EltTy->getBitWidth())) { // Fold the icmp based on the value of C // If C is M copies of an iK sized bit pattern, // then: // => %E = extractelement <N x iK> %vec, i32 Elem // icmp <pred> iK %SplatVal, <pattern> Value *Extract = Builder.CreateExtractElement(Vec, Elem); - Value *NewC = ConstantInt::get(EltTy, C.trunc(EltTy->getBitWidth())); + Value *NewC = ConstantInt::get(EltTy, C->trunc(EltTy->getBitWidth())); return new ICmpInst(Pred, Extract, NewC); } } @@ -2606,13 +2742,9 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) 
{ return I; } - if (auto *BCI = dyn_cast<BitCastInst>(Cmp.getOperand(0))) { - if (Instruction *I = foldICmpBitCastConstant(Cmp, BCI, *C)) + if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0))) + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C)) return I; - } - - if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, *C)) - return I; return nullptr; } @@ -2711,24 +2843,6 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, if (C == *BOC && C.isPowerOf2()) return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, BO, Constant::getNullValue(RHS->getType())); - - // Don't perform the following transforms if the AND has multiple uses - if (!BO->hasOneUse()) - break; - - // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 - if (BOC->isSignMask()) { - Constant *Zero = Constant::getNullValue(BOp0->getType()); - auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; - return new ICmpInst(NewPred, BOp0, Zero); - } - - // ((X & ~7) == 0) --> X < 8 - if (C.isNullValue() && (~(*BOC) + 1).isPowerOf2()) { - Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1)); - auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; - return new ICmpInst(NewPred, BOp0, NegBOC); - } } break; } @@ -2756,14 +2870,10 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, return nullptr; } -/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. -Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, - const APInt &C) { - IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)); - if (!II || !Cmp.isEquality()) - return nullptr; - - // Handle icmp {eq|ne} <intrinsic>, Constant. +/// Fold an equality icmp with LLVM intrinsic and constant operand. +Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, + IntrinsicInst *II, + const APInt &C) { Type *Ty = II->getType(); unsigned BitWidth = C.getBitWidth(); switch (II->getIntrinsicID()) { @@ -2823,6 +2933,65 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, return nullptr; } +/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. +Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, + IntrinsicInst *II, + const APInt &C) { + if (Cmp.isEquality()) + return foldICmpEqIntrinsicWithConstant(Cmp, II, C); + + Type *Ty = II->getType(); + unsigned BitWidth = C.getBitWidth(); + switch (II->getIntrinsicID()) { + case Intrinsic::ctlz: { + // ctlz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX < 0b00010000 + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && C.ult(BitWidth)) { + unsigned Num = C.getLimitedValue(); + APInt Limit = APInt::getOneBitSet(BitWidth, BitWidth - Num - 1); + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, + II->getArgOperand(0), ConstantInt::get(Ty, Limit)); + } + + // ctlz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX > 0b00011111 + if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && + C.uge(1) && C.ule(BitWidth)) { + unsigned Num = C.getLimitedValue(); + APInt Limit = APInt::getLowBitsSet(BitWidth, BitWidth - Num); + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, + II->getArgOperand(0), ConstantInt::get(Ty, Limit)); + } + break; + } + case Intrinsic::cttz: { + // Limit to one use to ensure we don't increase instruction count. 
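The new ctlz case rewrites an unsigned compare of the leading-zero count into a range check on the operand: more than N leading zeros simply means the value is below a power of two. A check of the ugt form for 32-bit values; std::countl_zero returns 32 for zero, which the fold handles consistently:

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  auto check = [](uint32_t x) {
    bool viaCtlz = std::countl_zero(x) > 3;      // ctlz(x) u> 3
    bool viaCmp  = x < (1u << 28);               // x u< 0x10000000
    assert(viaCtlz == viaCmp);
  };

  check(0);                      // ctlz(0) == 32
  check((1u << 28) - 1);         // largest value with more than 3 leading zeros
  check(1u << 28);               // smallest value with exactly 3
  check(UINT32_MAX);
  for (uint64_t x = 0; x <= UINT32_MAX; x += 12345)   // coarse sweep
    check(static_cast<uint32_t>(x));
  return 0;
}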
+ if (!II->hasOneUse()) + return nullptr; + + // cttz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX & 0b00001111 == 0 + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && C.ult(BitWidth)) { + APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue() + 1); + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, + Builder.CreateAnd(II->getArgOperand(0), Mask), + ConstantInt::getNullValue(Ty)); + } + + // cttz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX & 0b00000111 != 0 + if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && + C.uge(1) && C.ule(BitWidth)) { + APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue()); + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, + Builder.CreateAnd(II->getArgOperand(0), Mask), + ConstantInt::getNullValue(Ty)); + } + break; + } + default: + break; + } + + return nullptr; +} + /// Handle icmp with constant (but not simple integer constant) RHS. Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -2983,6 +3152,10 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, // x s> x & (-1 >> y) -> x s> (-1 >> y) if (X != I.getOperand(0)) // X must be on LHS of comparison! return nullptr; // Ignore the other case. + if (!match(M, m_Constant())) // Can not do this fold with non-constant. + return nullptr; + if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. + return nullptr; DstPred = ICmpInst::Predicate::ICMP_SGT; break; case ICmpInst::Predicate::ICMP_SGE: @@ -3009,6 +3182,10 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, // x s<= x & (-1 >> y) -> x s<= (-1 >> y) if (X != I.getOperand(0)) // X must be on LHS of comparison! return nullptr; // Ignore the other case. + if (!match(M, m_Constant())) // Can not do this fold with non-constant. + return nullptr; + if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. + return nullptr; DstPred = ICmpInst::Predicate::ICMP_SLE; break; default: @@ -3093,6 +3270,64 @@ foldICmpWithTruncSignExtendedVal(ICmpInst &I, return T1; } +// Given pattern: +// icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0 +// we should move shifts to the same hand of 'and', i.e. rewrite as +// icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x) +// We are only interested in opposite logical shifts here. +// If we can, we want to end up creating 'lshr' shift. +static Value * +foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, + InstCombiner::BuilderTy &Builder) { + if (!I.isEquality() || !match(I.getOperand(1), m_Zero()) || + !I.getOperand(0)->hasOneUse()) + return nullptr; + + auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value()); + auto m_AnyLShr = m_LShr(m_Value(), m_Value()); + + // Look for an 'and' of two (opposite) logical shifts. + // Pick the single-use shift as XShift. + Value *XShift, *YShift; + if (!match(I.getOperand(0), + m_c_And(m_OneUse(m_CombineAnd(m_AnyLogicalShift, m_Value(XShift))), + m_CombineAnd(m_AnyLogicalShift, m_Value(YShift))))) + return nullptr; + + // If YShift is a single-use 'lshr', swap the shifts around. + if (match(YShift, m_OneUse(m_AnyLShr))) + std::swap(XShift, YShift); + + // The shifts must be in opposite directions. + Instruction::BinaryOps XShiftOpcode = + cast<BinaryOperator>(XShift)->getOpcode(); + if (XShiftOpcode == cast<BinaryOperator>(YShift)->getOpcode()) + return nullptr; // Do not care about same-direction shifts here. 
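The matching cttz case, restricted to single-use intrinsics so the instruction count does not grow, turns a trailing-zero-count compare into a low-bit mask test. A check of the ugt form; x == 0 is covered, since countr_zero(0) is 32 while the low four bits of zero are clear:

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x = 0; x < 1u << 20; ++x) {
    bool viaCttz = std::countr_zero(x) > 3;   // cttz(x) u> 3
    bool viaMask = (x & 0xFu) == 0;           // low four bits all zero
    assert(viaCttz == viaMask);
  }
  return 0;
}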
+ + Value *X, *XShAmt, *Y, *YShAmt; + match(XShift, m_BinOp(m_Value(X), m_Value(XShAmt))); + match(YShift, m_BinOp(m_Value(Y), m_Value(YShAmt))); + + // Can we fold (XShAmt+YShAmt) ? + Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, XShAmt, YShAmt, + SQ.getWithInstruction(&I)); + if (!NewShAmt) + return nullptr; + // Is the new shift amount smaller than the bit width? + // FIXME: could also rely on ConstantRange. + unsigned BitWidth = X->getType()->getScalarSizeInBits(); + if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, + APInt(BitWidth, BitWidth)))) + return nullptr; + // All good, we can do this fold. The shift is the same that was for X. + Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr + ? Builder.CreateLShr(X, NewShAmt) + : Builder.CreateShl(X, NewShAmt); + Value *T1 = Builder.CreateAnd(T0, Y); + return Builder.CreateICmp(I.getPredicate(), T1, + Constant::getNullValue(X->getType())); +} + /// Try to fold icmp (binop), X or icmp X, (binop). /// TODO: A large part of this logic is duplicated in InstSimplify's /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code @@ -3448,6 +3683,9 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder)) return replaceInstUsesWith(I, V); + if (Value *V = foldShiftIntoShiftInAnotherHandOfAndInICmp(I, SQ, Builder)) + return replaceInstUsesWith(I, V); + return nullptr; } @@ -3688,6 +3926,30 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { match(Op1, m_BitReverse(m_Value(B))))) return new ICmpInst(Pred, A, B); + // Canonicalize checking for a power-of-2-or-zero value: + // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants) + // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants) + if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()), + m_Deferred(A)))) || + !match(Op1, m_ZeroInt())) + A = nullptr; + + // (A & -A) == A --> ctpop(A) < 2 (four commuted variants) + // (-A & A) != A --> ctpop(A) > 1 (four commuted variants) + if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1))))) + A = Op1; + else if (match(Op1, + m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0))))) + A = Op0; + + if (A) { + Type *Ty = A->getType(); + CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A); + return Pred == ICmpInst::ICMP_EQ + ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, ConstantInt::get(Ty, 2)) + : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1)); + } + return nullptr; } @@ -3698,7 +3960,6 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { Value *LHSCIOp = LHSCI->getOperand(0); Type *SrcTy = LHSCIOp->getType(); Type *DestTy = LHSCI->getType(); - Value *RHSCIOp; // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the // integer type is the same size as the pointer type. @@ -3740,7 +4001,7 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { if (auto *CI = dyn_cast<CastInst>(ICmp.getOperand(1))) { // Not an extension from the same type? 
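foldICmpEquality now canonicalizes the classic power-of-two-or-zero test into a ctpop compare (the A & -A variants in the hunk are handled the same way). The bit trick and the popcount form agree:

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t a = 0; a < 1u << 20; ++a) {
    bool clearLowestBit = (a & (a - 1)) == 0;   // zero or a power of two
    bool viaPopcount    = std::popcount(a) < 2; // the canonical form
    assert(clearLowestBit == viaPopcount);
  }
  return 0;
}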
- RHSCIOp = CI->getOperand(0); + Value *RHSCIOp = CI->getOperand(0); if (RHSCIOp->getType() != LHSCIOp->getType()) return nullptr; @@ -3813,104 +4074,83 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { return BinaryOperator::CreateNot(Result); } -bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, - Value *RHS, Instruction &OrigI, - Value *&Result, Constant *&Overflow) { +static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { + switch (BinaryOp) { + default: + llvm_unreachable("Unsupported binary op"); + case Instruction::Add: + case Instruction::Sub: + return match(RHS, m_Zero()); + case Instruction::Mul: + return match(RHS, m_One()); + } +} + +OverflowResult InstCombiner::computeOverflow( + Instruction::BinaryOps BinaryOp, bool IsSigned, + Value *LHS, Value *RHS, Instruction *CxtI) const { + switch (BinaryOp) { + default: + llvm_unreachable("Unsupported binary op"); + case Instruction::Add: + if (IsSigned) + return computeOverflowForSignedAdd(LHS, RHS, CxtI); + else + return computeOverflowForUnsignedAdd(LHS, RHS, CxtI); + case Instruction::Sub: + if (IsSigned) + return computeOverflowForSignedSub(LHS, RHS, CxtI); + else + return computeOverflowForUnsignedSub(LHS, RHS, CxtI); + case Instruction::Mul: + if (IsSigned) + return computeOverflowForSignedMul(LHS, RHS, CxtI); + else + return computeOverflowForUnsignedMul(LHS, RHS, CxtI); + } +} + +bool InstCombiner::OptimizeOverflowCheck( + Instruction::BinaryOps BinaryOp, bool IsSigned, Value *LHS, Value *RHS, + Instruction &OrigI, Value *&Result, Constant *&Overflow) { if (OrigI.isCommutative() && isa<Constant>(LHS) && !isa<Constant>(RHS)) std::swap(LHS, RHS); - auto SetResult = [&](Value *OpResult, Constant *OverflowVal, bool ReuseName) { - Result = OpResult; - Overflow = OverflowVal; - if (ReuseName) - Result->takeName(&OrigI); - return true; - }; - // If the overflow check was an add followed by a compare, the insertion point // may be pointing to the compare. We want to insert the new instructions // before the add in case there are uses of the add between the add and the // compare. Builder.SetInsertPoint(&OrigI); - switch (OCF) { - case OCF_INVALID: - llvm_unreachable("bad overflow check kind!"); - - case OCF_UNSIGNED_ADD: { - OverflowResult OR = computeOverflowForUnsignedAdd(LHS, RHS, &OrigI); - if (OR == OverflowResult::NeverOverflows) - return SetResult(Builder.CreateNUWAdd(LHS, RHS), Builder.getFalse(), - true); - - if (OR == OverflowResult::AlwaysOverflows) - return SetResult(Builder.CreateAdd(LHS, RHS), Builder.getTrue(), true); - - // Fall through uadd into sadd - LLVM_FALLTHROUGH; - } - case OCF_SIGNED_ADD: { - // X + 0 -> {X, false} - if (match(RHS, m_Zero())) - return SetResult(LHS, Builder.getFalse(), false); - - // We can strength reduce this signed add into a regular add if we can prove - // that it will never overflow. 
- if (OCF == OCF_SIGNED_ADD) - if (willNotOverflowSignedAdd(LHS, RHS, OrigI)) - return SetResult(Builder.CreateNSWAdd(LHS, RHS), Builder.getFalse(), - true); - break; - } - - case OCF_UNSIGNED_SUB: - case OCF_SIGNED_SUB: { - // X - 0 -> {X, false} - if (match(RHS, m_Zero())) - return SetResult(LHS, Builder.getFalse(), false); - - if (OCF == OCF_SIGNED_SUB) { - if (willNotOverflowSignedSub(LHS, RHS, OrigI)) - return SetResult(Builder.CreateNSWSub(LHS, RHS), Builder.getFalse(), - true); - } else { - if (willNotOverflowUnsignedSub(LHS, RHS, OrigI)) - return SetResult(Builder.CreateNUWSub(LHS, RHS), Builder.getFalse(), - true); - } - break; - } - - case OCF_UNSIGNED_MUL: { - OverflowResult OR = computeOverflowForUnsignedMul(LHS, RHS, &OrigI); - if (OR == OverflowResult::NeverOverflows) - return SetResult(Builder.CreateNUWMul(LHS, RHS), Builder.getFalse(), - true); - if (OR == OverflowResult::AlwaysOverflows) - return SetResult(Builder.CreateMul(LHS, RHS), Builder.getTrue(), true); - LLVM_FALLTHROUGH; + if (isNeutralValue(BinaryOp, RHS)) { + Result = LHS; + Overflow = Builder.getFalse(); + return true; } - case OCF_SIGNED_MUL: - // X * undef -> undef - if (isa<UndefValue>(RHS)) - return SetResult(RHS, UndefValue::get(Builder.getInt1Ty()), false); - - // X * 0 -> {0, false} - if (match(RHS, m_Zero())) - return SetResult(RHS, Builder.getFalse(), false); - - // X * 1 -> {X, false} - if (match(RHS, m_One())) - return SetResult(LHS, Builder.getFalse(), false); - if (OCF == OCF_SIGNED_MUL) - if (willNotOverflowSignedMul(LHS, RHS, OrigI)) - return SetResult(Builder.CreateNSWMul(LHS, RHS), Builder.getFalse(), - true); - break; + switch (computeOverflow(BinaryOp, IsSigned, LHS, RHS, &OrigI)) { + case OverflowResult::MayOverflow: + return false; + case OverflowResult::AlwaysOverflowsLow: + case OverflowResult::AlwaysOverflowsHigh: + Result = Builder.CreateBinOp(BinaryOp, LHS, RHS); + Result->takeName(&OrigI); + Overflow = Builder.getTrue(); + return true; + case OverflowResult::NeverOverflows: + Result = Builder.CreateBinOp(BinaryOp, LHS, RHS); + Result->takeName(&OrigI); + Overflow = Builder.getFalse(); + if (auto *Inst = dyn_cast<Instruction>(Result)) { + if (IsSigned) + Inst->setHasNoSignedWrap(); + else + Inst->setHasNoUnsignedWrap(); + } + return true; } - return false; + llvm_unreachable("Unexpected overflow result"); } /// Recognize and process idiom involving test for multiplication @@ -4084,8 +4324,8 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, MulA = Builder.CreateZExt(A, MulType); if (WidthB < MulWidth) MulB = Builder.CreateZExt(B, MulType); - Value *F = Intrinsic::getDeclaration(I.getModule(), - Intrinsic::umul_with_overflow, MulType); + Function *F = Intrinsic::getDeclaration( + I.getModule(), Intrinsic::umul_with_overflow, MulType); CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul"); IC.Worklist.Add(MulInstr); @@ -4881,61 +5121,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return New; } - // Zero-equality and sign-bit checks are preserved through sitofp + bitcast. 
- Value *X; - if (match(Op0, m_BitCast(m_SIToFP(m_Value(X))))) { - // icmp eq (bitcast (sitofp X)), 0 --> icmp eq X, 0 - // icmp ne (bitcast (sitofp X)), 0 --> icmp ne X, 0 - // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0 - // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0 - if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT || - Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) && - match(Op1, m_Zero())) - return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); - - // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1 - if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One())) - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1)); - - // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1 - if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes())) - return new ICmpInst(Pred, X, ConstantInt::getAllOnesValue(X->getType())); - } - - // Zero-equality checks are preserved through unsigned floating-point casts: - // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0 - // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0 - if (match(Op0, m_BitCast(m_UIToFP(m_Value(X))))) - if (I.isEquality() && match(Op1, m_Zero())) - return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); - - // Test to see if the operands of the icmp are casted versions of other - // values. If the ptr->ptr cast can be stripped off both arguments, we do so - // now. - if (BitCastInst *CI = dyn_cast<BitCastInst>(Op0)) { - if (Op0->getType()->isPointerTy() && - (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) { - // We keep moving the cast from the left operand over to the right - // operand, where it can often be eliminated completely. - Op0 = CI->getOperand(0); - - // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast - // so eliminate it as well. - if (BitCastInst *CI2 = dyn_cast<BitCastInst>(Op1)) - Op1 = CI2->getOperand(0); - - // If Op1 is a constant, we can fold the cast into the constant. - if (Op0->getType() != Op1->getType()) { - if (Constant *Op1C = dyn_cast<Constant>(Op1)) { - Op1 = ConstantExpr::getBitCast(Op1C, Op0->getType()); - } else { - // Otherwise, cast the RHS right before the icmp - Op1 = Builder.CreateBitCast(Op1, Op0->getType()); - } - } - return new ICmpInst(I.getPredicate(), Op0, Op1); - } - } + if (Instruction *Res = foldICmpBitCast(I, Builder)) + return Res; if (isa<CastInst>(Op0)) { // Handle the special case of: icmp (cast bool to X), <cst> @@ -4984,8 +5171,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { isa<IntegerType>(A->getType())) { Value *Result; Constant *Overflow; - if (OptimizeOverflowCheck(OCF_UNSIGNED_ADD, A, B, *AddI, Result, - Overflow)) { + if (OptimizeOverflowCheck(Instruction::Add, /*Signed*/false, A, B, + *AddI, Result, Overflow)) { replaceInstUsesWith(*AddI, Result); return replaceInstUsesWith(I, Overflow); } @@ -5411,6 +5598,8 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { return replaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' + Type *OpType = Op0->getType(); + assert(OpType == Op1->getType() && "fcmp with different-typed operands?"); if (Op0 == Op1) { switch (Pred) { default: break; @@ -5420,7 +5609,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { case FCmpInst::FCMP_UNE: // True if unordered or not equal // Canonicalize these to be 'fcmp uno %X, 0.0'. 
I.setPredicate(FCmpInst::FCMP_UNO); - I.setOperand(1, Constant::getNullValue(Op0->getType())); + I.setOperand(1, Constant::getNullValue(OpType)); return &I; case FCmpInst::FCMP_ORD: // True if ordered (no nans) @@ -5429,7 +5618,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { case FCmpInst::FCMP_OLE: // True if ordered and less than or equal // Canonicalize these to be 'fcmp ord %X, 0.0'. I.setPredicate(FCmpInst::FCMP_ORD); - I.setOperand(1, Constant::getNullValue(Op0->getType())); + I.setOperand(1, Constant::getNullValue(OpType)); return &I; } } @@ -5438,15 +5627,20 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // then canonicalize the operand to 0.0. if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) { if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI)) { - I.setOperand(0, ConstantFP::getNullValue(Op0->getType())); + I.setOperand(0, ConstantFP::getNullValue(OpType)); return &I; } if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI)) { - I.setOperand(1, ConstantFP::getNullValue(Op0->getType())); + I.setOperand(1, ConstantFP::getNullValue(OpType)); return &I; } } + // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y + Value *X, *Y; + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) + return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I); + // Test if the FCmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing // any other folding. This helps out other analyses which understand @@ -5465,7 +5659,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0: // fcmp Pred X, -0.0 --> fcmp Pred X, 0.0 if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) { - I.setOperand(1, ConstantFP::getNullValue(Op1->getType())); + I.setOperand(1, ConstantFP::getNullValue(OpType)); return &I; } @@ -5505,12 +5699,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { if (Instruction *R = foldFabsWithFcmpZero(I)) return R; - Value *X, *Y; if (match(Op0, m_FNeg(m_Value(X)))) { - // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y - if (match(Op1, m_FNeg(m_Value(Y)))) - return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I); - // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C Constant *C; if (match(Op1, m_Constant(C))) { diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index 2de41bd5bef5..434b0d591215 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -1,9 +1,8 @@ //===- InstCombineInternal.h - InstCombine pass internals -------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -53,13 +52,14 @@ namespace llvm { class APInt; class AssumptionCache; -class CallSite; +class BlockFrequencyInfo; class DataLayout; class DominatorTree; class GEPOperator; class GlobalVariable; class LoopInfo; class OptimizationRemarkEmitter; +class ProfileSummaryInfo; class TargetLibraryInfo; class User; @@ -185,40 +185,6 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { return false; } -/// Specific patterns of overflow check idioms that we match. -enum OverflowCheckFlavor { - OCF_UNSIGNED_ADD, - OCF_SIGNED_ADD, - OCF_UNSIGNED_SUB, - OCF_SIGNED_SUB, - OCF_UNSIGNED_MUL, - OCF_SIGNED_MUL, - - OCF_INVALID -}; - -/// Returns the OverflowCheckFlavor corresponding to a overflow_with_op -/// intrinsic. -static inline OverflowCheckFlavor -IntrinsicIDToOverflowCheckFlavor(unsigned ID) { - switch (ID) { - default: - return OCF_INVALID; - case Intrinsic::uadd_with_overflow: - return OCF_UNSIGNED_ADD; - case Intrinsic::sadd_with_overflow: - return OCF_SIGNED_ADD; - case Intrinsic::usub_with_overflow: - return OCF_UNSIGNED_SUB; - case Intrinsic::ssub_with_overflow: - return OCF_SIGNED_SUB; - case Intrinsic::umul_with_overflow: - return OCF_UNSIGNED_MUL; - case Intrinsic::smul_with_overflow: - return OCF_SIGNED_MUL; - } -} - /// Some binary operators require special handling to avoid poison and undefined /// behavior. If a constant vector has undef elements, replace those undefs with /// identity constants if possible because those are always safe to execute. @@ -306,6 +272,8 @@ private: const DataLayout &DL; const SimplifyQuery SQ; OptimizationRemarkEmitter &ORE; + BlockFrequencyInfo *BFI; + ProfileSummaryInfo *PSI; // Optional analyses. When non-null, these can both be used to do better // combining and will be updated to reflect any changes. @@ -317,11 +285,11 @@ public: InstCombiner(InstCombineWorklist &Worklist, BuilderTy &Builder, bool MinimizeSize, bool ExpensiveCombines, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, - OptimizationRemarkEmitter &ORE, const DataLayout &DL, - LoopInfo *LI) + OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, + ProfileSummaryInfo *PSI, const DataLayout &DL, LoopInfo *LI) : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), - DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), LI(LI) {} + DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), BFI(BFI), PSI(PSI), LI(LI) {} /// Run the combiner over the entire worklist until it is empty. 
/// @@ -345,6 +313,7 @@ public: // I - Change was made, I is still valid, I may be dead though // otherwise - Change was made, replace I with returned instruction // + Instruction *visitFNeg(UnaryOperator &I); Instruction *visitAdd(BinaryOperator &I); Instruction *visitFAdd(BinaryOperator &I); Value *OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty); @@ -394,6 +363,7 @@ public: Instruction *visitSelectInst(SelectInst &SI); Instruction *visitCallInst(CallInst &CI); Instruction *visitInvokeInst(InvokeInst &II); + Instruction *visitCallBrInst(CallBrInst &CBI); Instruction *SliceUpIllegalIntegerPHI(PHINode &PN); Instruction *visitPHINode(PHINode &PN); @@ -403,6 +373,7 @@ public: Instruction *visitFree(CallInst &FI); Instruction *visitLoadInst(LoadInst &LI); Instruction *visitStoreInst(StoreInst &SI); + Instruction *visitAtomicRMWInst(AtomicRMWInst &SI); Instruction *visitBranchInst(BranchInst &BI); Instruction *visitFenceInst(FenceInst &FI); Instruction *visitSwitchInst(SwitchInst &SI); @@ -464,16 +435,22 @@ private: /// operation in OperationResult and result of the overflow check in /// OverflowResult, and return true. If no simplification is possible, /// returns false. - bool OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, Value *RHS, + bool OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp, bool IsSigned, + Value *LHS, Value *RHS, Instruction &CtxI, Value *&OperationResult, Constant *&OverflowResult); - Instruction *visitCallSite(CallSite CS); + Instruction *visitCallBase(CallBase &Call); Instruction *tryOptimizeCall(CallInst *CI); - bool transformConstExprCastCall(CallSite CS); - Instruction *transformCallThroughTrampoline(CallSite CS, - IntrinsicInst *Tramp); - + bool transformConstExprCastCall(CallBase &Call); + Instruction *transformCallThroughTrampoline(CallBase &Call, + IntrinsicInst &Tramp); + + Value *simplifyMaskedLoad(IntrinsicInst &II); + Instruction *simplifyMaskedStore(IntrinsicInst &II); + Instruction *simplifyMaskedGather(IntrinsicInst &II); + Instruction *simplifyMaskedScatter(IntrinsicInst &II); + /// Transform (zext icmp) to bitwise / integer operations in order to /// eliminate it. /// @@ -592,6 +569,8 @@ private: Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D); Value *getSelectCondition(Value *A, Value *B); + Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II); + public: /// Inserts an instruction \p New before instruction \p Old /// @@ -647,6 +626,16 @@ public: return InsertValueInst::Create(Struct, Result, 0); } + /// Create and insert the idiom we use to indicate a block is unreachable + /// without having to rewrite the CFG from within InstCombine. + void CreateNonTerminatorUnreachable(Instruction *InsertAt) { + auto &Ctx = InsertAt->getContext(); + new StoreInst(ConstantInt::getTrue(Ctx), + UndefValue::get(Type::getInt1PtrTy(Ctx)), + InsertAt); + } + + /// Combiner aware instruction erasure. /// /// When dealing with an instruction that has side effects or produces a void @@ -703,7 +692,7 @@ public: } OverflowResult computeOverflowForSignedMul(const Value *LHS, - const Value *RHS, + const Value *RHS, const Instruction *CxtI) const { return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT); } @@ -731,6 +720,10 @@ public: return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT); } + OverflowResult computeOverflow( + Instruction::BinaryOps BinaryOp, bool IsSigned, + Value *LHS, Value *RHS, Instruction *CxtI) const; + /// Maximum size of array considered when transforming. 
uint64_t MaxArraySizeForCombine; @@ -802,8 +795,7 @@ private: Value *simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, APInt DemandedElts, - int DmaskIdx = -1, - int TFCIdx = -1); + int DmaskIdx = -1); Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth = 0); @@ -868,8 +860,6 @@ private: Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); - Instruction *foldICmpBitCastConstant(ICmpInst &Cmp, BitCastInst *Bitcast, - const APInt &C); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, const APInt &C); Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, @@ -904,7 +894,10 @@ private: Instruction *foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, BinaryOperator *BO, const APInt &C); - Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI, const APInt &C); + Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI, IntrinsicInst *II, + const APInt &C); + Instruction *foldICmpEqIntrinsicWithConstant(ICmpInst &ICI, IntrinsicInst *II, + const APInt &C); // Helpers of visitSelectInst(). Instruction *foldSelectExtConst(SelectInst &Sel); diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 76ab614090fa..054fb7da09a2 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -1,9 +1,8 @@ //===- InstCombineLoadStoreAlloca.cpp -------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -89,29 +88,29 @@ isOnlyCopiedFromConstantGlobal(Value *V, MemTransferInst *&TheCopy, continue; } - if (auto CS = CallSite(I)) { + if (auto *Call = dyn_cast<CallBase>(I)) { // If this is the function being called then we treat it like a load and // ignore it. - if (CS.isCallee(&U)) + if (Call->isCallee(&U)) continue; - unsigned DataOpNo = CS.getDataOperandNo(&U); - bool IsArgOperand = CS.isArgOperand(&U); + unsigned DataOpNo = Call->getDataOperandNo(&U); + bool IsArgOperand = Call->isArgOperand(&U); // Inalloca arguments are clobbered by the call. - if (IsArgOperand && CS.isInAllocaArgument(DataOpNo)) + if (IsArgOperand && Call->isInAllocaArgument(DataOpNo)) return false; // If this is a readonly/readnone call site, then we know it is just a // load (but one that potentially returns the value itself), so we can // ignore it if we know that the value isn't captured. - if (CS.onlyReadsMemory() && - (CS.getInstruction()->use_empty() || CS.doesNotCapture(DataOpNo))) + if (Call->onlyReadsMemory() && + (Call->use_empty() || Call->doesNotCapture(DataOpNo))) continue; // If this is being passed as a byval argument, the caller is making a // copy, so it is only a read of the alloca. 
- if (IsArgOperand && CS.isByValArgument(DataOpNo)) + if (IsArgOperand && Call->isByValArgument(DataOpNo)) continue; } @@ -213,8 +212,8 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { Type *IdxTy = IC.getDataLayout().getIntPtrType(AI.getType()); Value *NullIdx = Constant::getNullValue(IdxTy); Value *Idx[2] = {NullIdx, NullIdx}; - Instruction *GEP = - GetElementPtrInst::CreateInBounds(New, Idx, New->getName() + ".sub"); + Instruction *GEP = GetElementPtrInst::CreateInBounds( + NewTy, New, Idx, New->getName() + ".sub"); IC.InsertNewInstBefore(GEP, *It); // Now make everything use the getelementptr instead of the original @@ -299,7 +298,7 @@ void PointerReplacer::replace(Instruction *I) { if (auto *LT = dyn_cast<LoadInst>(I)) { auto *V = getReplacement(LT->getPointerOperand()); assert(V && "Operand not replaced"); - auto *NewI = new LoadInst(V); + auto *NewI = new LoadInst(I->getType(), V); NewI->takeName(LT); IC.InsertNewInstWith(NewI, *LT); IC.replaceInstUsesWith(*LT, NewI); @@ -466,7 +465,7 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT NewPtr = IC.Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS)); LoadInst *NewLoad = IC.Builder.CreateAlignedLoad( - NewPtr, LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); + NewTy, NewPtr, LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); MDBuilder MDB(NewLoad->getContext()); for (const auto &MDPair : MD) { @@ -631,7 +630,7 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { // infinite loop). if (!Ty->isIntegerTy() && Ty->isSized() && DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) && - DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty) && + DL.typeSizeEqualsStoreSize(Ty) && !DL.isNonIntegralPointerType(Ty) && !isMinMaxWithLoads( peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true))) { @@ -725,7 +724,8 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), Name + ".elt"); auto EltAlign = MinAlign(Align, SL->getElementOffset(i)); - auto *L = IC.Builder.CreateAlignedLoad(Ptr, EltAlign, Name + ".unpack"); + auto *L = IC.Builder.CreateAlignedLoad(ST->getElementType(i), Ptr, + EltAlign, Name + ".unpack"); // Propagate AA metadata. It'll still be valid on the narrowed load. AAMDNodes AAMD; LI.getAAMetadata(AAMD); @@ -775,8 +775,8 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { }; auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), Name + ".elt"); - auto *L = IC.Builder.CreateAlignedLoad(Ptr, MinAlign(Align, Offset), - Name + ".unpack"); + auto *L = IC.Builder.CreateAlignedLoad( + AT->getElementType(), Ptr, MinAlign(Align, Offset), Name + ".unpack"); AAMDNodes AAMD; LI.getAAMetadata(AAMD); L->setAAMetadata(AAMD); @@ -1064,12 +1064,16 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { if (SelectInst *SI = dyn_cast<SelectInst>(Op)) { // load (select (Cond, &V1, &V2)) --> select(Cond, load &V1, load &V2). 
unsigned Align = LI.getAlignment(); - if (isSafeToLoadUnconditionally(SI->getOperand(1), Align, DL, SI) && - isSafeToLoadUnconditionally(SI->getOperand(2), Align, DL, SI)) { - LoadInst *V1 = Builder.CreateLoad(SI->getOperand(1), - SI->getOperand(1)->getName()+".val"); - LoadInst *V2 = Builder.CreateLoad(SI->getOperand(2), - SI->getOperand(2)->getName()+".val"); + if (isSafeToLoadUnconditionally(SI->getOperand(1), LI.getType(), Align, + DL, SI) && + isSafeToLoadUnconditionally(SI->getOperand(2), LI.getType(), Align, + DL, SI)) { + LoadInst *V1 = + Builder.CreateLoad(LI.getType(), SI->getOperand(1), + SI->getOperand(1)->getName() + ".val"); + LoadInst *V2 = + Builder.CreateLoad(LI.getType(), SI->getOperand(2), + SI->getOperand(2)->getName() + ".val"); assert(LI.isUnordered() && "implied by above"); V1->setAlignment(Align); V1->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); @@ -1436,6 +1440,12 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { } } + // If we have a store to a location which is known constant, we can conclude + // that the store must be storing the constant value (else the memory + // wouldn't be constant), and this must be a noop. + if (AA->pointsToConstantMemory(Ptr)) + return eraseInstFromFunction(SI); + // Do really simple DSE, to catch cases where there are several consecutive // stores to the same location, separated by a few arithmetic operations. This // situation often occurs with bitfield accesses. diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 7e99f3e4e500..cc753ce05313 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1,9 +1,8 @@ //===- InstCombineMulDivRem.cpp -------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -375,11 +374,13 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); // Sink negation: -X * Y --> -(X * Y) - if (match(Op0, m_OneUse(m_FNeg(m_Value(X))))) + // But don't transform constant expressions because there's an inverse fold. + if (match(Op0, m_OneUse(m_FNeg(m_Value(X)))) && !isa<ConstantExpr>(Op0)) return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op1, &I), &I); // Sink negation: Y * -X --> -(X * Y) - if (match(Op1, m_OneUse(m_FNeg(m_Value(X))))) + // But don't transform constant expressions because there's an inverse fold. 
+ if (match(Op1, m_OneUse(m_FNeg(m_Value(X)))) && !isa<ConstantExpr>(Op1)) return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op0, &I), &I); // fabs(X) * fabs(X) -> X * X @@ -431,6 +432,14 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } + Value *Z; + if (match(&I, m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), + m_Value(Z)))) { + // Sink division: (X / Y) * Z --> (X * Z) / Y + Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I); + return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I); + } + // sqrt(X) * sqrt(Y) -> sqrt(X * Y) // nnan disallows the possibility of returning a number if both operands are // negative (in that case, we should return NaN). @@ -442,6 +451,45 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { return replaceInstUsesWith(I, Sqrt); } + // Like the similar transform in instsimplify, this requires 'nsz' because + // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0. + if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && + Op0->hasNUses(2)) { + // Peek through fdiv to find squaring of square root: + // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y + if (match(Op0, m_FDiv(m_Value(X), + m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) { + Value *XX = Builder.CreateFMulFMF(X, X, &I); + return BinaryOperator::CreateFDivFMF(XX, Y, &I); + } + // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X) + if (match(Op0, m_FDiv(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y)), + m_Value(X)))) { + Value *XX = Builder.CreateFMulFMF(X, X, &I); + return BinaryOperator::CreateFDivFMF(Y, XX, &I); + } + } + + // exp(X) * exp(Y) -> exp(X + Y) + // Match as long as at least one of exp has only one use. + if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + Value *XY = Builder.CreateFAddFMF(X, Y, &I); + Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I); + return replaceInstUsesWith(I, Exp); + } + + // exp2(X) * exp2(Y) -> exp2(X + Y) + // Match as long as at least one of exp2 has only one use. + if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + Value *XY = Builder.CreateFAddFMF(X, Y, &I); + Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I); + return replaceInstUsesWith(I, Exp2); + } + // (X*Y) * X => (X*X) * Y where Y != X // The purpose is two-fold: // 1) to form a power expression (of X). 
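The exp/exp2 folds in the hunk above rest on the identity e^x * e^y = e^(x+y) (and 2^x * 2^y = 2^(x+y)), which is exact over the reals but only approximate in floating point, which is why the rewritten IR carries the original instruction's fast-math flags. A standalone sketch of the identity with an approximate comparison to absorb rounding (illustrative values only, not LLVM code):

#include <cassert>
#include <cmath>

// Relative comparison; the two evaluation orders differ only by rounding.
static bool roughlyEqual(double A, double B) {
  return std::fabs(A - B) <= 1e-12 * std::fabs(B);
}

int main() {
  const double Xs[] = {0.0, 0.5, 1.0, -2.25, 3.75};
  for (double X : Xs)
    for (double Y : Xs) {
      assert(roughlyEqual(std::exp(X) * std::exp(Y), std::exp(X + Y)));
      assert(roughlyEqual(std::exp2(X) * std::exp2(Y), std::exp2(X + Y)));
    }
}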
@@ -576,7 +624,7 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, if (IsSigned && C1.isMinSignedValue() && C2.isAllOnesValue()) return false; - APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned); + APInt Remainder(C1.getBitWidth(), /*val=*/0ULL, IsSigned); if (IsSigned) APInt::sdivrem(C1, C2, Quotient, Remainder); else @@ -613,7 +661,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { // (X / C1) / C2 -> X / (C1*C2) if ((IsSigned && match(Op0, m_SDiv(m_Value(X), m_APInt(C1)))) || (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_APInt(C1))))) { - APInt Product(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + APInt Product(C1->getBitWidth(), /*val=*/0ULL, IsSigned); if (!multiplyOverflows(*C1, *C2, Product, IsSigned)) return BinaryOperator::Create(I.getOpcode(), X, ConstantInt::get(Ty, Product)); @@ -621,7 +669,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { if ((IsSigned && match(Op0, m_NSWMul(m_Value(X), m_APInt(C1)))) || (!IsSigned && match(Op0, m_NUWMul(m_Value(X), m_APInt(C1))))) { - APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned); // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1. if (isMultiple(*C2, *C1, Quotient, IsSigned)) { @@ -645,7 +693,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { if ((IsSigned && match(Op0, m_NSWShl(m_Value(X), m_APInt(C1))) && *C1 != C1->getBitWidth() - 1) || (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))))) { - APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned); APInt C1Shifted = APInt::getOneBitSet( C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue())); @@ -977,6 +1025,10 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))) return BinaryOperator::CreateNeg(Op0); + // X / INT_MIN --> X == INT_MIN + if (match(Op1, m_SignMask())) + return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), I.getType()); + const APInt *Op1C; if (match(Op1, m_APInt(Op1C))) { // sdiv exact X, C --> ashr exact X, log2(C) @@ -1001,22 +1053,25 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { Value *NarrowOp = Builder.CreateSDiv(Op0Src, NarrowDivisor); return new SExtInst(NarrowOp, Op0->getType()); } - } - if (Constant *RHS = dyn_cast<Constant>(Op1)) { - // X/INT_MIN -> X == INT_MIN - if (RHS->isMinSignedValue()) - return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), I.getType()); - - // -X/C --> X/-C provided the negation doesn't overflow. - Value *X; - if (match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) { - auto *BO = BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(RHS)); + // -X / C --> X / -C (if the negation doesn't overflow). + // TODO: This could be enhanced to handle arbitrary vector constants by + // checking if all elements are not the min-signed-val. + if (!Op1C->isMinSignedValue() && + match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) { + Constant *NegC = ConstantInt::get(I.getType(), -(*Op1C)); + Instruction *BO = BinaryOperator::CreateSDiv(X, NegC); BO->setIsExact(I.isExact()); return BO; } } + // -X / Y --> -(X / Y) + Value *Y; + if (match(&I, m_SDiv(m_OneUse(m_NSWSub(m_Zero(), m_Value(X))), m_Value(Y)))) + return BinaryOperator::CreateNSWNeg( + Builder.CreateSDiv(X, Y, I.getName(), I.isExact())); + // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a udiv. 
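The "X / INT_MIN --> X == INT_MIN" fold above works because every signed value other than INT_MIN has magnitude strictly smaller than |INT_MIN|, so the truncating quotient is 0, while INT_MIN / INT_MIN is 1. An exhaustive 8-bit check of the identity (illustration only; both C++ division and sdiv truncate toward zero):

#include <cassert>
#include <cstdint>

int main() {
  for (int X = -128; X != 128; ++X)
    // x / INT_MIN is 1 exactly when x == INT_MIN, otherwise 0.
    assert(X / INT8_MIN == (X == INT8_MIN ? 1 : 0));
}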
APInt Mask(APInt::getSignMask(I.getType()->getScalarSizeInBits())); @@ -1161,7 +1216,8 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { IRBuilder<> B(&I); IRBuilder<>::FastMathFlagGuard FMFGuard(B); B.setFastMathFlags(I.getFastMathFlags()); - AttributeList Attrs = CallSite(Op0).getCalledFunction()->getAttributes(); + AttributeList Attrs = + cast<CallBase>(Op0)->getCalledFunction()->getAttributes(); Value *Res = emitUnaryFloatFnCall(X, &TLI, LibFunc_tan, LibFunc_tanf, LibFunc_tanl, B, Attrs); if (IsCot) @@ -1305,6 +1361,11 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { } } + // -X srem Y --> -(X srem Y) + Value *X, *Y; + if (match(&I, m_SRem(m_OneUse(m_NSWSub(m_Zero(), m_Value(X))), m_Value(Y)))) + return BinaryOperator::CreateNSWNeg(Builder.CreateSRem(X, Y)); + // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a urem. APInt Mask(APInt::getSignMask(I.getType()->getScalarSizeInBits())); diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 7603cf4d7958..5820ab726637 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -1,9 +1,8 @@ //===- InstCombinePHI.cpp -------------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -596,7 +595,8 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { Value *InVal = FirstLI->getOperand(0); NewPN->addIncoming(InVal, PN.getIncomingBlock(0)); - LoadInst *NewLI = new LoadInst(NewPN, "", isVolatile, LoadAlignment); + LoadInst *NewLI = + new LoadInst(FirstLI->getType(), NewPN, "", isVolatile, LoadAlignment); unsigned KnownIDs[] = { LLVMContext::MD_tbaa, @@ -1004,6 +1004,11 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { !isa<ConstantInt>(UserI->getOperand(1))) return nullptr; + // Bail on out of range shifts. + unsigned SizeInBits = UserI->getType()->getScalarSizeInBits(); + if (cast<ConstantInt>(UserI->getOperand(1))->getValue().uge(SizeInBits)) + return nullptr; + unsigned Shift = cast<ConstantInt>(UserI->getOperand(1))->getZExtValue(); PHIUsers.push_back(PHIUsageRecord(PHIId, Shift, UserI->user_back())); } diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index faf58a08976d..aefaf5af1750 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1,9 +1,8 @@ //===- InstCombineSelect.cpp ----------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -293,6 +292,8 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, return nullptr; // If this is a cast from the same type, merge. + Value *Cond = SI.getCondition(); + Type *CondTy = Cond->getType(); if (TI->getNumOperands() == 1 && TI->isCast()) { Type *FIOpndTy = FI->getOperand(0)->getType(); if (TI->getOperand(0)->getType() != FIOpndTy) @@ -300,7 +301,6 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, // The select condition may be a vector. We may only change the operand // type if the vector width remains the same (and matches the condition). - Type *CondTy = SI.getCondition()->getType(); if (CondTy->isVectorTy()) { if (!FIOpndTy->isVectorTy()) return nullptr; @@ -327,12 +327,24 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, // Fold this by inserting a select from the input values. Value *NewSI = - Builder.CreateSelect(SI.getCondition(), TI->getOperand(0), - FI->getOperand(0), SI.getName() + ".v", &SI); + Builder.CreateSelect(Cond, TI->getOperand(0), FI->getOperand(0), + SI.getName() + ".v", &SI); return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI, TI->getType()); } + // Cond ? -X : -Y --> -(Cond ? X : Y) + Value *X, *Y; + if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y))) && + (TI->hasOneUse() || FI->hasOneUse())) { + Value *NewSel = Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI); + // TODO: Remove the hack for the binop form when the unary op is optimized + // properly with all IR passes. + if (TI->getOpcode() != Instruction::FNeg) + return BinaryOperator::CreateFNegFMF(NewSel, cast<BinaryOperator>(TI)); + return UnaryOperator::CreateFNeg(NewSel); + } + // Only handle binary operators (including two-operand getelementptr) with // one-use here. As with the cast case above, it may be possible to relax the // one-use constraint, but that needs be examined carefully since it may not @@ -374,13 +386,12 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, // If the select condition is a vector, the operands of the original select's // operands also must be vectors. This may not be the case for getelementptr // for example. - if (SI.getCondition()->getType()->isVectorTy() && - (!OtherOpT->getType()->isVectorTy() || - !OtherOpF->getType()->isVectorTy())) + if (CondTy->isVectorTy() && (!OtherOpT->getType()->isVectorTy() || + !OtherOpF->getType()->isVectorTy())) return nullptr; // If we reach here, they do have operations in common. - Value *NewSI = Builder.CreateSelect(SI.getCondition(), OtherOpT, OtherOpF, + Value *NewSI = Builder.CreateSelect(Cond, OtherOpT, OtherOpF, SI.getName() + ".v", &SI); Value *Op0 = MatchIsOpZero ? MatchOp : NewSI; Value *Op1 = MatchIsOpZero ? 
NewSI : MatchOp; @@ -521,6 +532,46 @@ static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp, } /// We want to turn: +/// (select (icmp sgt x, C), lshr (X, Y), ashr (X, Y)); iff C s>= -1 +/// (select (icmp slt x, C), ashr (X, Y), lshr (X, Y)); iff C s>= 0 +/// into: +/// ashr (X, Y) +static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal, + Value *FalseVal, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred = IC->getPredicate(); + Value *CmpLHS = IC->getOperand(0); + Value *CmpRHS = IC->getOperand(1); + if (!CmpRHS->getType()->isIntOrIntVectorTy()) + return nullptr; + + Value *X, *Y; + unsigned Bitwidth = CmpRHS->getType()->getScalarSizeInBits(); + if ((Pred != ICmpInst::ICMP_SGT || + !match(CmpRHS, + m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, -1)))) && + (Pred != ICmpInst::ICMP_SLT || + !match(CmpRHS, + m_SpecificInt_ICMP(ICmpInst::ICMP_SGE, APInt(Bitwidth, 0))))) + return nullptr; + + // Canonicalize so that ashr is in FalseVal. + if (Pred == ICmpInst::ICMP_SLT) + std::swap(TrueVal, FalseVal); + + if (match(TrueVal, m_LShr(m_Value(X), m_Value(Y))) && + match(FalseVal, m_AShr(m_Specific(X), m_Specific(Y))) && + match(CmpLHS, m_Specific(X))) { + const auto *Ashr = cast<Instruction>(FalseVal); + // if lshr is not exact and ashr is, this new ashr must not be exact. + bool IsExact = Ashr->isExact() && cast<Instruction>(TrueVal)->isExact(); + return Builder.CreateAShr(X, Y, IC->getName(), IsExact); + } + + return nullptr; +} + +/// We want to turn: /// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) /// into: /// (or (shl (and X, C1), C3), Y) @@ -623,11 +674,7 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, return Builder.CreateOr(V, Y); } -/// Transform patterns such as: (a > b) ? a - b : 0 -/// into: ((a > b) ? a : b) - b) -/// This produces a canonical max pattern that is more easily recognized by the -/// backend and converted into saturated subtraction instructions if those -/// exist. +/// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b). /// There are 8 commuted/swapped variants of this pattern. /// TODO: Also support a - UMIN(a,b) patterns. static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, @@ -669,11 +716,73 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, if (!TrueVal->hasOneUse()) return nullptr; - // All checks passed, convert to canonical unsigned saturated subtraction - // form: sub(max()). - // (a > b) ? a - b : 0 -> ((a > b) ? a : b) - b) - Value *Max = Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); - return IsNegative ? Builder.CreateSub(B, Max) : Builder.CreateSub(Max, B); + // (a > b) ? a - b : 0 -> usub.sat(a, b) + // (a > b) ? b - a : 0 -> -usub.sat(a, b) + Value *Result = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, B); + if (IsNegative) + Result = Builder.CreateNeg(Result); + return Result; +} + +static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { + if (!Cmp->hasOneUse()) + return nullptr; + + // Match unsigned saturated add with constant. + Value *Cmp0 = Cmp->getOperand(0); + Value *Cmp1 = Cmp->getOperand(1); + ICmpInst::Predicate Pred = Cmp->getPredicate(); + Value *X; + const APInt *C, *CmpC; + if (Pred == ICmpInst::ICMP_ULT && + match(TVal, m_Add(m_Value(X), m_APInt(C))) && X == Cmp0 && + match(FVal, m_AllOnes()) && match(Cmp1, m_APInt(CmpC)) && *CmpC == ~*C) { + // (X u< ~C) ? 
(X + C) : -1 --> uadd.sat(X, C) + return Builder.CreateBinaryIntrinsic( + Intrinsic::uadd_sat, X, ConstantInt::get(X->getType(), *C)); + } + + // Match unsigned saturated add of 2 variables with an unnecessary 'not'. + // There are 8 commuted variants. + // Canonicalize -1 (saturated result) to true value of the select. Just + // swapping the compare operands is legal, because the selected value is the + // same in case of equality, so we can interchange u< and u<=. + if (match(FVal, m_AllOnes())) { + std::swap(TVal, FVal); + std::swap(Cmp0, Cmp1); + } + if (!match(TVal, m_AllOnes())) + return nullptr; + + // Canonicalize predicate to 'ULT'. + if (Pred == ICmpInst::ICMP_UGT) { + Pred = ICmpInst::ICMP_ULT; + std::swap(Cmp0, Cmp1); + } + if (Pred != ICmpInst::ICMP_ULT) + return nullptr; + + // Match unsigned saturated add of 2 variables with an unnecessary 'not'. + Value *Y; + if (match(Cmp0, m_Not(m_Value(X))) && + match(FVal, m_c_Add(m_Specific(X), m_Value(Y))) && Y == Cmp1) { + // (~X u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y) + // (~X u< Y) ? -1 : (Y + X) --> uadd.sat(X, Y) + return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, X, Y); + } + // The 'not' op may be included in the sum but not the compare. + X = Cmp0; + Y = Cmp1; + if (match(FVal, m_c_Add(m_Not(m_Specific(X)), m_Specific(Y)))) { + // (X u< Y) ? -1 : (~X + Y) --> uadd.sat(~X, Y) + // (X u< Y) ? -1 : (Y + ~X) --> uadd.sat(Y, ~X) + BinaryOperator *BO = cast<BinaryOperator>(FVal); + return Builder.CreateBinaryIntrinsic( + Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1)); + } + + return nullptr; } /// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single @@ -1043,12 +1152,18 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); + if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); + if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + return Changed ? &SI : nullptr; } @@ -1496,6 +1611,43 @@ static Instruction *foldSelectCmpXchg(SelectInst &SI) { return nullptr; } +static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X, + Value *Y, + InstCombiner::BuilderTy &Builder) { + assert(SelectPatternResult::isMinOrMax(SPF) && "Expected min/max pattern"); + bool IsUnsigned = SPF == SelectPatternFlavor::SPF_UMIN || + SPF == SelectPatternFlavor::SPF_UMAX; + // TODO: If InstSimplify could fold all cases where C2 <= C1, we could change + // the constant value check to an assert. 
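The usub.sat/uadd.sat canonicalizations above encode two identities on unsigned integers: (a > b) ? a - b : 0 is a saturating subtract, and (x u< ~c) ? x + c : -1 is a saturating add (the boundary case x == ~c selects the same value either way, which is why the strictness of the compare does not matter). An exhaustive 8-bit sketch of both identities (illustration only, not LLVM code):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A != 256; ++A)
    for (unsigned B = 0; B != 256; ++B) {
      uint8_t X = A, Y = B;
      // (a > b) ? a - b : 0  ==  usub.sat(a, b)
      unsigned USubSat = A > B ? A - B : 0;
      assert((X > Y ? uint8_t(X - Y) : uint8_t(0)) == USubSat);
      // (x u< ~c) ? x + c : -1  ==  uadd.sat(x, c)
      unsigned UAddSat = std::min(A + B, 255u);
      assert((X < uint8_t(~Y) ? uint8_t(X + Y) : uint8_t(255)) == UAddSat);
    }
}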
+ Value *A; + const APInt *C1, *C2; + if (IsUnsigned && match(X, m_NUWAdd(m_Value(A), m_APInt(C1))) && + match(Y, m_APInt(C2)) && C2->uge(*C1) && X->hasNUses(2)) { + // umin (add nuw A, C1), C2 --> add nuw (umin A, C2 - C1), C1 + // umax (add nuw A, C1), C2 --> add nuw (umax A, C2 - C1), C1 + Value *NewMinMax = createMinMax(Builder, SPF, A, + ConstantInt::get(X->getType(), *C2 - *C1)); + return BinaryOperator::CreateNUW(BinaryOperator::Add, NewMinMax, + ConstantInt::get(X->getType(), *C1)); + } + + if (!IsUnsigned && match(X, m_NSWAdd(m_Value(A), m_APInt(C1))) && + match(Y, m_APInt(C2)) && X->hasNUses(2)) { + bool Overflow; + APInt Diff = C2->ssub_ov(*C1, Overflow); + if (!Overflow) { + // smin (add nsw A, C1), C2 --> add nsw (smin A, C2 - C1), C1 + // smax (add nsw A, C1), C2 --> add nsw (smax A, C2 - C1), C1 + Value *NewMinMax = createMinMax(Builder, SPF, A, + ConstantInt::get(X->getType(), Diff)); + return BinaryOperator::CreateNSW(BinaryOperator::Add, NewMinMax, + ConstantInt::get(X->getType(), *C1)); + } + } + + return nullptr; +} + /// Reduce a sequence of min/max with a common operand. static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, Value *RHS, @@ -1757,37 +1909,55 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // NOTE: if we wanted to, this is where to detect MIN/MAX } + } - // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need - // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. We - // also require nnan because we do not want to unintentionally change the - // sign of a NaN value. - Value *X = FCI->getOperand(0); - FCmpInst::Predicate Pred = FCI->getPredicate(); - if (match(FCI->getOperand(1), m_AnyZeroFP()) && FCI->hasNoNaNs()) { - // (X <= +/-0.0) ? (0.0 - X) : X --> fabs(X) - // (X > +/-0.0) ? X : (0.0 - X) --> fabs(X) - if ((X == FalseVal && Pred == FCmpInst::FCMP_OLE && - match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) || - (X == TrueVal && Pred == FCmpInst::FCMP_OGT && - match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(X))))) { - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, FCI); - return replaceInstUsesWith(SI, Fabs); - } - // With nsz: - // (X < +/-0.0) ? -X : X --> fabs(X) - // (X <= +/-0.0) ? -X : X --> fabs(X) - // (X > +/-0.0) ? X : -X --> fabs(X) - // (X >= +/-0.0) ? X : -X --> fabs(X) - if (FCI->hasNoSignedZeros() && - ((X == FalseVal && match(TrueVal, m_FNeg(m_Specific(X))) && - (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE)) || - (X == TrueVal && match(FalseVal, m_FNeg(m_Specific(X))) && - (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE)))) { - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, FCI); - return replaceInstUsesWith(SI, Fabs); - } - } + // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need + // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. We + // also require nnan because we do not want to unintentionally change the + // sign of a NaN value. + // FIXME: These folds should test/propagate FMF from the select, not the + // fsub or fneg. + // (X <= +/-0.0) ? 
(0.0 - X) : X --> fabs(X) + Instruction *FSub; + if (match(CondVal, m_FCmp(Pred, m_Specific(FalseVal), m_AnyZeroFP())) && + match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(FalseVal))) && + match(TrueVal, m_Instruction(FSub)) && FSub->hasNoNaNs() && + (Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULE)) { + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FalseVal, FSub); + return replaceInstUsesWith(SI, Fabs); + } + // (X > +/-0.0) ? X : (0.0 - X) --> fabs(X) + if (match(CondVal, m_FCmp(Pred, m_Specific(TrueVal), m_AnyZeroFP())) && + match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(TrueVal))) && + match(FalseVal, m_Instruction(FSub)) && FSub->hasNoNaNs() && + (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_UGT)) { + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, TrueVal, FSub); + return replaceInstUsesWith(SI, Fabs); + } + // With nnan and nsz: + // (X < +/-0.0) ? -X : X --> fabs(X) + // (X <= +/-0.0) ? -X : X --> fabs(X) + Instruction *FNeg; + if (match(CondVal, m_FCmp(Pred, m_Specific(FalseVal), m_AnyZeroFP())) && + match(TrueVal, m_FNeg(m_Specific(FalseVal))) && + match(TrueVal, m_Instruction(FNeg)) && + FNeg->hasNoNaNs() && FNeg->hasNoSignedZeros() && + (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE || + Pred == FCmpInst::FCMP_ULT || Pred == FCmpInst::FCMP_ULE)) { + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FalseVal, FNeg); + return replaceInstUsesWith(SI, Fabs); + } + // With nnan and nsz: + // (X > +/-0.0) ? X : -X --> fabs(X) + // (X >= +/-0.0) ? X : -X --> fabs(X) + if (match(CondVal, m_FCmp(Pred, m_Specific(TrueVal), m_AnyZeroFP())) && + match(FalseVal, m_FNeg(m_Specific(TrueVal))) && + match(FalseVal, m_Instruction(FNeg)) && + FNeg->hasNoNaNs() && FNeg->hasNoSignedZeros() && + (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE || + Pred == FCmpInst::FCMP_UGT || Pred == FCmpInst::FCMP_UGE)) { + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, TrueVal, FNeg); + return replaceInstUsesWith(SI, Fabs); } // See if we are selecting two values based on a comparison of the two values. @@ -1895,11 +2065,27 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *I = moveNotAfterMinMax(RHS, LHS)) return I; + if (Instruction *I = moveAddAfterMinMax(SPF, LHS, RHS, Builder)) + return I; + if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) return I; } } + // Canonicalize select of FP values where NaN and -0.0 are not valid as + // minnum/maxnum intrinsics. + if (isa<FPMathOperator>(SI) && SI.hasNoNaNs() && SI.hasNoSignedZeros()) { + Value *X, *Y; + if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y)))) + return replaceInstUsesWith( + SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI)); + + if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y)))) + return replaceInstUsesWith( + SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI)); + } + // See if we can fold the select into a phi node if the condition is a select. if (auto *PN = dyn_cast<PHINode>(SI.getCondition())) // The true/false values have to be live in the PHI predecessor's blocks. 
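The fabs folds above rely on the fact that, for any non-NaN input, selecting between X and the fsub form 0.0 - X on a sign test reproduces fabs(X) exactly, including for -0.0; only the fneg form additionally needs 'nsz', because negating +0.0 flips the zero's sign. A standalone sketch of the fsub-form identity (illustrative values only, not LLVM code):

#include <cassert>
#include <cmath>

// (X <= +/-0.0) ? (0.0 - X) : X
static double selectFabs(double X) {
  return X <= 0.0 ? 0.0 - X : X;
}

int main() {
  const double Vals[] = {-3.5, -0.0, 0.0, 1.25, 1e300};
  for (double X : Vals) {
    double R = selectFabs(X);
    // Matches fabs numerically and never produces a negative zero.
    assert(R == std::fabs(X) && !std::signbit(R));
  }
}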
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index c562d45a9e2b..c821292400cd 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -1,9 +1,8 @@ //===- InstCombineShifts.cpp ----------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -21,6 +20,51 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" +// Given pattern: +// (x shiftopcode Q) shiftopcode K +// we should rewrite it as +// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) +// This is valid for any shift, but they must be identical. +static Instruction * +reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0, + const SimplifyQuery &SQ) { + // Look for: (x shiftopcode ShAmt0) shiftopcode ShAmt1 + Value *X, *ShAmt1, *ShAmt0; + Instruction *Sh1; + if (!match(Sh0, m_Shift(m_CombineAnd(m_Shift(m_Value(X), m_Value(ShAmt1)), + m_Instruction(Sh1)), + m_Value(ShAmt0)))) + return nullptr; + + // The shift opcodes must be identical. + Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode(); + if (ShiftOpcode != Sh1->getOpcode()) + return nullptr; + // Can we fold (ShAmt0+ShAmt1) ? + Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, ShAmt0, ShAmt1, + SQ.getWithInstruction(Sh0)); + if (!NewShAmt) + return nullptr; // Did not simplify. + // Is the new shift amount smaller than the bit width? + // FIXME: could also rely on ConstantRange. + unsigned BitWidth = X->getType()->getScalarSizeInBits(); + if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, + APInt(BitWidth, BitWidth)))) + return nullptr; + // All good, we can do this fold. + BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt); + // If both of the original shifts had the same flag set, preserve the flag. + if (ShiftOpcode == Instruction::BinaryOps::Shl) { + NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() && + Sh1->hasNoUnsignedWrap()); + NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() && + Sh1->hasNoSignedWrap()); + } else { + NewShift->setIsExact(Sh0->isExact() && Sh1->isExact()); + } + return NewShift; +} + Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); assert(Op0->getType() == Op1->getType()); @@ -39,6 +83,10 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) return Res; + if (Instruction *NewShift = + reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)) + return NewShift; + // (C1 shift (A add C2)) -> (C1 shift C2) shift A) // iff A and C2 are both positive. Value *A; @@ -313,35 +361,17 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, // If this is a bitwise operator or add with a constant RHS we might be able // to pull it through a shift. 
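The shift-amount reassociation introduced above, (x shift Q) shift K --> x shift (Q+K) when Q+K u< bitwidth, holds for same-direction shifts because every bit either survives both shifts at the same final position or is discarded either way. An exhaustive 8-bit sketch for the two logical shifts (illustration only, not LLVM code):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned V = 0; V != 256; ++V)
    for (unsigned Q = 0; Q != 8; ++Q)
      for (unsigned K = 0; Q + K < 8; ++K) {
        uint8_t X = static_cast<uint8_t>(V);
        // Two shls equal one shl by the summed amount (truncating to 8 bits).
        assert(uint8_t(uint8_t(X << Q) << K) == uint8_t(X << (Q + K)));
        // Likewise for two lshrs.
        assert(uint8_t(uint8_t(X >> Q) >> K) == uint8_t(X >> (Q + K)));
      }
}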
static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, - BinaryOperator *BO, - const APInt &C) { - bool IsValid = true; // Valid only for And, Or Xor, - bool HighBitSet = false; // Transform ifhigh bit of constant set? - + BinaryOperator *BO) { switch (BO->getOpcode()) { - default: IsValid = false; break; // Do not perform transform! + default: + return false; // Do not perform transform! case Instruction::Add: - IsValid = Shift.getOpcode() == Instruction::Shl; - break; + return Shift.getOpcode() == Instruction::Shl; case Instruction::Or: case Instruction::Xor: - HighBitSet = false; - break; case Instruction::And: - HighBitSet = true; - break; + return true; } - - // If this is a signed shift right, and the high bit is modified - // by the logical operation, do not perform the transformation. - // The HighBitSet boolean indicates the value of the high bit of - // the constant which would cause it to be modified for this - // operation. - // - if (IsValid && Shift.getOpcode() == Instruction::AShr) - IsValid = C.isNegative() == HighBitSet; - - return IsValid; } Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, @@ -508,7 +538,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // shift is the only use, we can pull it out of the shift. const APInt *Op0C; if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { - if (canShiftBinOpWithConstantRHS(I, Op0BO, *Op0C)) { + if (canShiftBinOpWithConstantRHS(I, Op0BO)) { Constant *NewRHS = ConstantExpr::get(I.getOpcode(), cast<Constant>(Op0BO->getOperand(1)), Op1); @@ -552,7 +582,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, const APInt *C; if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && match(TBO->getOperand(1), m_APInt(C)) && - canShiftBinOpWithConstantRHS(I, TBO, *C)) { + canShiftBinOpWithConstantRHS(I, TBO)) { Constant *NewRHS = ConstantExpr::get(I.getOpcode(), cast<Constant>(TBO->getOperand(1)), Op1); @@ -571,7 +601,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, const APInt *C; if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && match(FBO->getOperand(1), m_APInt(C)) && - canShiftBinOpWithConstantRHS(I, FBO, *C)) { + canShiftBinOpWithConstantRHS(I, FBO)) { Constant *NewRHS = ConstantExpr::get(I.getOpcode(), cast<Constant>(FBO->getOperand(1)), Op1); @@ -601,6 +631,8 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + const APInt *ShAmtAPInt; if (match(Op1, m_APInt(ShAmtAPInt))) { unsigned ShAmt = ShAmtAPInt->getZExtValue(); @@ -689,6 +721,12 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1)); } + // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 + if (match(Op0, m_One()) && + match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X)))) + return BinaryOperator::CreateLShr( + ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 9bf87d024607..e0d85c4b49ae 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -1,9 +1,8 @@ //===- InstCombineSimplifyDemanded.cpp ------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// 
This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -366,10 +365,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits InputKnown(SrcBitWidth); if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) return I; - Known = InputKnown.zextOrTrunc(BitWidth); - // Any top bits are known to be zero. - if (BitWidth > SrcBitWidth) - Known.Zero.setBitsFrom(SrcBitWidth); + assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?"); + Known = InputKnown.zextOrTrunc(BitWidth, + true /* ExtendedBitsAreKnownZero */); assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; } @@ -967,26 +965,16 @@ InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1, } /// Implement SimplifyDemandedVectorElts for amdgcn buffer and image intrinsics. +/// +/// Note: This only supports non-TFE/LWE image intrinsic calls; those have +/// struct returns. Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, APInt DemandedElts, - int DMaskIdx, - int TFCIdx) { + int DMaskIdx) { unsigned VWidth = II->getType()->getVectorNumElements(); if (VWidth == 1) return nullptr; - // Need to change to new instruction format - ConstantInt *TFC = nullptr; - bool TFELWEEnabled = false; - if (TFCIdx > 0) { - TFC = dyn_cast<ConstantInt>(II->getArgOperand(TFCIdx)); - TFELWEEnabled = TFC->getZExtValue() & 0x1 // TFE - || TFC->getZExtValue() & 0x2; // LWE - } - - if (TFELWEEnabled) - return nullptr; // TFE not yet supported - ConstantInt *NewDMask = nullptr; if (DMaskIdx < 0) { @@ -994,10 +982,7 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, // below. DemandedElts = (1 << DemandedElts.getActiveBits()) - 1; } else { - ConstantInt *DMask = dyn_cast<ConstantInt>(II->getArgOperand(DMaskIdx)); - if (!DMask) - return nullptr; // non-constant dmask is not supported by codegen - + ConstantInt *DMask = cast<ConstantInt>(II->getArgOperand(DMaskIdx)); unsigned DMaskVal = DMask->getZExtValue() & 0xf; // Mask off values that are undefined because the dmask doesn't cover them @@ -1018,8 +1003,7 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, NewDMask = ConstantInt::get(DMask->getType(), NewDMaskVal); } - // TODO: Handle 3 vectors when supported in code gen. - unsigned NewNumElts = PowerOf2Ceil(DemandedElts.countPopulation()); + unsigned NewNumElts = DemandedElts.countPopulation(); if (!NewNumElts) return UndefValue::get(II->getType()); @@ -1035,13 +1019,12 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, getIntrinsicInfoTableEntries(IID, Table); ArrayRef<Intrinsic::IITDescriptor> TableRef = Table; + // Validate function argument and return types, extracting overloaded types + // along the way. 
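Aside: the SimplifyDemandedUseBits hunk above replaces the manual "set the top bits to known zero" step with zextOrTrunc and the ExtendedBitsAreKnownZero flag (the parameter name is taken from the comment in the hunk). A tiny sketch of the underlying fact, modeling "known zero" as a plain mask:

#include <cassert>
#include <cstdint>

int main() {
  uint8_t Src = 0x5A;
  uint32_t Zext = Src;                      // zext i8 %Src to i32
  const uint32_t ExtendedBits = 0xFFFFFF00; // bits 8..31
  assert((Zext & ExtendedBits) == 0 && "extended bits are always zero");
  return 0;
}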
FunctionType *FTy = II->getCalledFunction()->getFunctionType(); SmallVector<Type *, 6> OverloadTys; - Intrinsic::matchIntrinsicType(FTy->getReturnType(), TableRef, OverloadTys); - for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) - Intrinsic::matchIntrinsicType(FTy->getParamType(i), TableRef, OverloadTys); + Intrinsic::matchIntrinsicSignature(FTy, TableRef, OverloadTys); - // Get the new return type overload of the intrinsic. Module *M = II->getParent()->getParent()->getParent(); Type *EltTy = II->getType()->getVectorElementType(); Type *NewTy = (NewNumElts == 1) ? EltTy : VectorType::get(EltTy, NewNumElts); @@ -1184,6 +1167,39 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, switch (I->getOpcode()) { default: break; + case Instruction::GetElementPtr: { + // The LangRef requires that struct geps have all constant indices. As + // such, we can't convert any operand to partial undef. + auto mayIndexStructType = [](GetElementPtrInst &GEP) { + for (auto I = gep_type_begin(GEP), E = gep_type_end(GEP); + I != E; I++) + if (I.isStruct()) + return true;; + return false; + }; + if (mayIndexStructType(cast<GetElementPtrInst>(*I))) + break; + + // Conservatively track the demanded elements back through any vector + // operands we may have. We know there must be at least one, or we + // wouldn't have a vector result to get here. Note that we intentionally + // merge the undef bits here since gepping with either an undef base or + // index results in undef. + for (unsigned i = 0; i < I->getNumOperands(); i++) { + if (isa<UndefValue>(I->getOperand(i))) { + // If the entire vector is undefined, just return this info. + UndefElts = EltMask; + return nullptr; + } + if (I->getOperand(i)->getType()->isVectorTy()) { + APInt UndefEltsOp(VWidth, 0); + simplifyAndSetOp(I, i, DemandedElts, UndefEltsOp); + UndefElts |= UndefEltsOp; + } + } + + break; + } case Instruction::InsertElement: { // If this is a variable index, we don't know which element it overwrites. // demand exactly the same input as we produce. @@ -1430,6 +1446,30 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); if (!II) break; switch (II->getIntrinsicID()) { + case Intrinsic::masked_gather: // fallthrough + case Intrinsic::masked_load: { + // Subtlety: If we load from a pointer, the pointer must be valid + // regardless of whether the element is demanded. Doing otherwise risks + // segfaults which didn't exist in the original program. + APInt DemandedPtrs(APInt::getAllOnesValue(VWidth)), + DemandedPassThrough(DemandedElts); + if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2))) + for (unsigned i = 0; i < VWidth; i++) { + Constant *CElt = CV->getAggregateElement(i); + if (CElt->isNullValue()) + DemandedPtrs.clearBit(i); + else if (CElt->isAllOnesValue()) + DemandedPassThrough.clearBit(i); + } + if (II->getIntrinsicID() == Intrinsic::masked_gather) + simplifyAndSetOp(II, 0, DemandedPtrs, UndefElts2); + simplifyAndSetOp(II, 3, DemandedPassThrough, UndefElts3); + + // Output elements are undefined if the element from both sources are. + // TODO: can strengthen via mask as well. 
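Aside: the masked_load/masked_gather demanded-elements logic above relies on the per-lane semantics of those intrinsics. A scalar model of a masked load, showing why a lane with mask bit 0 demands only the passthrough while a lane with mask bit 1 demands the memory operand (names are illustrative, not LLVM API):

#include <array>
#include <cassert>
#include <cstddef>

template <std::size_t N>
std::array<int, N> maskedLoadModel(const std::array<int, N> &Mem,
                                   const std::array<bool, N> &Mask,
                                   const std::array<int, N> &PassThrough) {
  std::array<int, N> Result{};
  for (std::size_t I = 0; I != N; ++I)
    Result[I] = Mask[I] ? Mem[I] : PassThrough[I]; // lane off -> passthrough
  return Result;
}

int main() {
  std::array<int, 4> Mem = {10, 20, 30, 40};
  std::array<bool, 4> Mask = {true, false, true, false};
  std::array<int, 4> Pass = {-1, -2, -3, -4};
  std::array<int, 4> R = maskedLoadModel(Mem, Mask, Pass);
  assert(R[0] == 10 && R[1] == -2 && R[2] == 30 && R[3] == -4);
  return 0;
}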
+ UndefElts = UndefElts2 & UndefElts3; + break; + } case Intrinsic::x86_xop_vfrcz_ss: case Intrinsic::x86_xop_vfrcz_sd: // The instructions for these intrinsics are speced to zero upper bits not @@ -1639,8 +1679,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts); default: { if (getAMDGPUImageDMaskIntrinsic(II->getIntrinsicID())) - return simplifyAMDGCNMemoryIntrinsicDemanded( - II, DemandedElts, 0, II->getNumArgOperands() - 2); + return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts, 0); break; } @@ -1667,5 +1706,10 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, UndefElts &= UndefElts2; } + // If we've proven all of the lanes undef, return an undef value. + // TODO: Intersect w/demanded lanes + if (UndefElts.isAllOnesValue()) + return UndefValue::get(I->getType());; + return MadeChange ? I : nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 0ad1fc0e791f..dc9abdd7f47a 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -1,9 +1,8 @@ //===- InstCombineVectorOps.cpp -------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -663,18 +662,17 @@ static bool isShuffleEquivalentToSelect(ShuffleVectorInst &Shuf) { return true; } -// Turn a chain of inserts that splats a value into a canonical insert + shuffle -// splat. That is: -// insertelt(insertelt(insertelt(insertelt X, %k, 0), %k, 1), %k, 2) ... -> -// shufflevector(insertelt(X, %k, 0), undef, zero) -static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { - // We are interested in the last insert in a chain. So, if this insert - // has a single user, and that user is an insert, bail. +/// Turn a chain of inserts that splats a value into an insert + shuffle: +/// insertelt(insertelt(insertelt(insertelt X, %k, 0), %k, 1), %k, 2) ... -> +/// shufflevector(insertelt(X, %k, 0), undef, zero) +static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { + // We are interested in the last insert in a chain. So if this insert has a + // single user and that user is an insert, bail. if (InsElt.hasOneUse() && isa<InsertElementInst>(InsElt.user_back())) return nullptr; - VectorType *VT = cast<VectorType>(InsElt.getType()); - int NumElements = VT->getNumElements(); + auto *VecTy = cast<VectorType>(InsElt.getType()); + unsigned NumElements = VecTy->getNumElements(); // Do not try to do this for a one-element vector, since that's a nop, // and will cause an inf-loop. @@ -706,24 +704,66 @@ static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { CurrIE = NextIE; } - // Make sure we've seen an insert into every element. - if (llvm::any_of(ElementPresent, [](bool Present) { return !Present; })) + // If this is just a single insertelement (not a sequence), we are done. + if (FirstIE == &InsElt) return nullptr; - // All right, create the insert + shuffle. 
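Aside: foldInsSequenceIntoSplat, whose updated version begins above, treats a chain of inserts of the same scalar as equivalent to one insert into lane 0 plus a zero-mask shuffle. A scalar model of both forms (illustrative only):

#include <array>
#include <cassert>
#include <cstddef>

template <std::size_t N> std::array<int, N> splatByInsertChain(int K) {
  std::array<int, N> V{}; // stands in for the undef base vector
  for (std::size_t I = 0; I != N; ++I)
    V[I] = K; // insertelement ..., K, I
  return V;
}

template <std::size_t N> std::array<int, N> splatByInsertAndShuffle(int K) {
  std::array<int, N> Tmp{};
  Tmp[0] = K; // insertelement undef, K, 0
  std::array<int, N> V{};
  for (std::size_t I = 0; I != N; ++I)
    V[I] = Tmp[0]; // shufflevector with an all-zero mask
  return V;
}

int main() {
  assert(splatByInsertChain<4>(7) == splatByInsertAndShuffle<4>(7));
  return 0;
}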
- Instruction *InsertFirst; - if (cast<ConstantInt>(FirstIE->getOperand(2))->isZero()) - InsertFirst = FirstIE; - else - InsertFirst = InsertElementInst::Create( - UndefValue::get(VT), SplatVal, - ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0), - "", &InsElt); + // If we are not inserting into an undef vector, make sure we've seen an + // insert into every element. + // TODO: If the base vector is not undef, it might be better to create a splat + // and then a select-shuffle (blend) with the base vector. + if (!isa<UndefValue>(FirstIE->getOperand(0))) + if (any_of(ElementPresent, [](bool Present) { return !Present; })) + return nullptr; + + // Create the insert + shuffle. + Type *Int32Ty = Type::getInt32Ty(InsElt.getContext()); + UndefValue *UndefVec = UndefValue::get(VecTy); + Constant *Zero = ConstantInt::get(Int32Ty, 0); + if (!cast<ConstantInt>(FirstIE->getOperand(2))->isZero()) + FirstIE = InsertElementInst::Create(UndefVec, SplatVal, Zero, "", &InsElt); - Constant *ZeroMask = ConstantAggregateZero::get( - VectorType::get(Type::getInt32Ty(InsElt.getContext()), NumElements)); + // Splat from element 0, but replace absent elements with undef in the mask. + SmallVector<Constant *, 16> Mask(NumElements, Zero); + for (unsigned i = 0; i != NumElements; ++i) + if (!ElementPresent[i]) + Mask[i] = UndefValue::get(Int32Ty); - return new ShuffleVectorInst(InsertFirst, UndefValue::get(VT), ZeroMask); + return new ShuffleVectorInst(FirstIE, UndefVec, ConstantVector::get(Mask)); +} + +/// Try to fold an insert element into an existing splat shuffle by changing +/// the shuffle's mask to include the index of this insert element. +static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { + // Check if the vector operand of this insert is a canonical splat shuffle. + auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0)); + if (!Shuf || !Shuf->isZeroEltSplat()) + return nullptr; + + // Check for a constant insertion index. + uint64_t IdxC; + if (!match(InsElt.getOperand(2), m_ConstantInt(IdxC))) + return nullptr; + + // Check if the splat shuffle's input is the same as this insert's scalar op. + Value *X = InsElt.getOperand(1); + Value *Op0 = Shuf->getOperand(0); + if (!match(Op0, m_InsertElement(m_Undef(), m_Specific(X), m_ZeroInt()))) + return nullptr; + + // Replace the shuffle mask element at the index of this insert with a zero. + // For example: + // inselt (shuf (inselt undef, X, 0), undef, <0,undef,0,undef>), X, 1 + // --> shuf (inselt undef, X, 0), undef, <0,0,0,undef> + unsigned NumMaskElts = Shuf->getType()->getVectorNumElements(); + SmallVector<Constant *, 16> NewMaskVec(NumMaskElts); + Type *I32Ty = IntegerType::getInt32Ty(Shuf->getContext()); + Constant *Zero = ConstantInt::getNullValue(I32Ty); + for (unsigned i = 0; i != NumMaskElts; ++i) + NewMaskVec[i] = i == IdxC ? Zero : Shuf->getMask()->getAggregateElement(i); + + Constant *NewMask = ConstantVector::get(NewMaskVec); + return new ShuffleVectorInst(Op0, UndefValue::get(Op0->getType()), NewMask); } /// If we have an insertelement instruction feeding into another insertelement @@ -864,30 +904,28 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { VecOp, ScalarOp, IdxOp, SQ.getWithInstruction(&IE))) return replaceInstUsesWith(IE, V); - // Inserting an undef or into an undefined place, remove this. 
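Aside: the new foldInsEltIntoSplat above absorbs an insert of the already-splatted scalar by setting that lane's mask slot to zero. A trivial sketch of the mask update, with -1 standing in for an undef mask element:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> Mask = {0, -1, 0, -1}; // splat of lane 0; lanes 1,3 undef
  unsigned InsertIdx = 1;                 // inselt <splat>, X, 1
  Mask[InsertIdx] = 0;                    // the shuffle now fills lane 1 too
  assert(Mask == (std::vector<int>{0, 0, 0, -1}));
  return 0;
}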
- if (isa<UndefValue>(ScalarOp) || isa<UndefValue>(IdxOp)) - replaceInstUsesWith(IE, VecOp); + // If the vector and scalar are both bitcast from the same element type, do + // the insert in that source type followed by bitcast. + Value *VecSrc, *ScalarSrc; + if (match(VecOp, m_BitCast(m_Value(VecSrc))) && + match(ScalarOp, m_BitCast(m_Value(ScalarSrc))) && + (VecOp->hasOneUse() || ScalarOp->hasOneUse()) && + VecSrc->getType()->isVectorTy() && !ScalarSrc->getType()->isVectorTy() && + VecSrc->getType()->getVectorElementType() == ScalarSrc->getType()) { + // inselt (bitcast VecSrc), (bitcast ScalarSrc), IdxOp --> + // bitcast (inselt VecSrc, ScalarSrc, IdxOp) + Value *NewInsElt = Builder.CreateInsertElement(VecSrc, ScalarSrc, IdxOp); + return new BitCastInst(NewInsElt, IE.getType()); + } // If the inserted element was extracted from some other vector and both - // indexes are constant, try to turn this into a shuffle. + // indexes are valid constants, try to turn this into a shuffle. uint64_t InsertedIdx, ExtractedIdx; Value *ExtVecOp; if (match(IdxOp, m_ConstantInt(InsertedIdx)) && match(ScalarOp, m_ExtractElement(m_Value(ExtVecOp), - m_ConstantInt(ExtractedIdx)))) { - unsigned NumInsertVectorElts = IE.getType()->getNumElements(); - unsigned NumExtractVectorElts = ExtVecOp->getType()->getVectorNumElements(); - if (ExtractedIdx >= NumExtractVectorElts) // Out of range extract. - return replaceInstUsesWith(IE, VecOp); - - if (InsertedIdx >= NumInsertVectorElts) // Out of range insert. - return replaceInstUsesWith(IE, UndefValue::get(IE.getType())); - - // If we are extracting a value from a vector, then inserting it right - // back into the same place, just use the input vector. - if (ExtVecOp == VecOp && ExtractedIdx == InsertedIdx) - return replaceInstUsesWith(IE, VecOp); - + m_ConstantInt(ExtractedIdx))) && + ExtractedIdx < ExtVecOp->getType()->getVectorNumElements()) { // TODO: Looking at the user(s) to determine if this insert is a // fold-to-shuffle opportunity does not match the usual instcombine // constraints. We should decide if the transform is worthy based only @@ -943,11 +981,12 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { if (Instruction *NewInsElt = hoistInsEltConst(IE, Builder)) return NewInsElt; - // Turn a sequence of inserts that broadcasts a scalar into a single - // insert + shufflevector. - if (Instruction *Broadcast = foldInsSequenceIntoBroadcast(IE)) + if (Instruction *Broadcast = foldInsSequenceIntoSplat(IE)) return Broadcast; + if (Instruction *Splat = foldInsEltIntoSplat(IE)) + return Splat; + return nullptr; } @@ -1172,7 +1211,14 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { SmallVector<Value*, 8> NewOps; bool NeedsRebuild = (Mask.size() != I->getType()->getVectorNumElements()); for (int i = 0, e = I->getNumOperands(); i != e; ++i) { - Value *V = evaluateInDifferentElementOrder(I->getOperand(i), Mask); + Value *V; + // Recursively call evaluateInDifferentElementOrder on vector arguments + // as well. E.g. GetElementPtr may have scalar operands even if the + // return value is a vector, so we need to examine the operand type. 
+ if (I->getOperand(i)->getType()->isVectorTy()) + V = evaluateInDifferentElementOrder(I->getOperand(i), Mask); + else + V = I->getOperand(i); NewOps.push_back(V); NeedsRebuild |= (V != I->getOperand(i)); } @@ -1337,6 +1383,41 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { return NewBO; } +/// If we have an insert of a scalar to a non-zero element of an undefined +/// vector and then shuffle that value, that's the same as inserting to the zero +/// element and shuffling. Splatting from the zero element is recognized as the +/// canonical form of splat. +static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, + InstCombiner::BuilderTy &Builder) { + Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); + Constant *Mask = Shuf.getMask(); + Value *X; + uint64_t IndexC; + + // Match a shuffle that is a splat to a non-zero element. + if (!match(Op0, m_OneUse(m_InsertElement(m_Undef(), m_Value(X), + m_ConstantInt(IndexC)))) || + !match(Op1, m_Undef()) || match(Mask, m_ZeroInt()) || IndexC == 0) + return nullptr; + + // Insert into element 0 of an undef vector. + UndefValue *UndefVec = UndefValue::get(Shuf.getType()); + Constant *Zero = Builder.getInt32(0); + Value *NewIns = Builder.CreateInsertElement(UndefVec, X, Zero); + + // Splat from element 0. Any mask element that is undefined remains undefined. + // For example: + // shuf (inselt undef, X, 2), undef, <2,2,undef> + // --> shuf (inselt undef, X, 0), undef, <0,0,undef> + unsigned NumMaskElts = Shuf.getType()->getVectorNumElements(); + SmallVector<Constant *, 16> NewMask(NumMaskElts, Zero); + for (unsigned i = 0; i != NumMaskElts; ++i) + if (isa<UndefValue>(Mask->getAggregateElement(i))) + NewMask[i] = Mask->getAggregateElement(i); + + return new ShuffleVectorInst(NewIns, UndefVec, ConstantVector::get(NewMask)); +} + /// Try to fold shuffles that are the equivalent of a vector select. static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, InstCombiner::BuilderTy &Builder, @@ -1344,6 +1425,15 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, if (!Shuf.isSelect()) return nullptr; + // Canonicalize to choose from operand 0 first. + unsigned NumElts = Shuf.getType()->getVectorNumElements(); + if (Shuf.getMaskValue(0) >= (int)NumElts) { + // TODO: Can we assert that both operands of a shuffle-select are not undef + // (otherwise, it would have been folded by instsimplify? + Shuf.commute(); + return &Shuf; + } + if (Instruction *I = foldSelectShuffleWith1Binop(Shuf)) return I; @@ -1499,6 +1589,11 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { if (!match(Op0, m_ShuffleVector(m_Value(X), m_Value(Y), m_Constant(Mask)))) return nullptr; + // Be conservative with shuffle transforms. If we can't kill the 1st shuffle, + // then combining may result in worse codegen. + if (!Op0->hasOneUse()) + return nullptr; + // We are extracting a subvector from a shuffle. Remove excess elements from // the 1st shuffle mask to eliminate the extract. // @@ -1588,6 +1683,72 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf) { return nullptr; } +static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { + // Match the operands as identity with padding (also known as concatenation + // with undef) shuffles of the same source type. The backend is expected to + // recreate these concatenations from a shuffle of narrow operands. 
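Aside: canonicalizeInsertSplat above depends on the observation that splatting lane C of "insertelement undef, X, C" produces the same vector as splatting lane 0 of "insertelement undef, X, 0". A scalar model at width 4, with 0 standing in for the undef lanes:

#include <array>
#include <cassert>

int main() {
  const int X = 42;

  std::array<int, 4> A{}; // insertelement undef, X, 2
  A[2] = X;
  std::array<int, 4> SplatFrom2;
  SplatFrom2.fill(A[2]);

  std::array<int, 4> B{}; // insertelement undef, X, 0
  B[0] = X;
  std::array<int, 4> SplatFrom0;
  SplatFrom0.fill(B[0]);

  assert(SplatFrom2 == SplatFrom0);
  return 0;
}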
+ auto *Shuffle0 = dyn_cast<ShuffleVectorInst>(Shuf.getOperand(0)); + auto *Shuffle1 = dyn_cast<ShuffleVectorInst>(Shuf.getOperand(1)); + if (!Shuffle0 || !Shuffle0->isIdentityWithPadding() || + !Shuffle1 || !Shuffle1->isIdentityWithPadding()) + return nullptr; + + // We limit this transform to power-of-2 types because we expect that the + // backend can convert the simplified IR patterns to identical nodes as the + // original IR. + // TODO: If we can verify the same behavior for arbitrary types, the + // power-of-2 checks can be removed. + Value *X = Shuffle0->getOperand(0); + Value *Y = Shuffle1->getOperand(0); + if (X->getType() != Y->getType() || + !isPowerOf2_32(Shuf.getType()->getVectorNumElements()) || + !isPowerOf2_32(Shuffle0->getType()->getVectorNumElements()) || + !isPowerOf2_32(X->getType()->getVectorNumElements()) || + isa<UndefValue>(X) || isa<UndefValue>(Y)) + return nullptr; + assert(isa<UndefValue>(Shuffle0->getOperand(1)) && + isa<UndefValue>(Shuffle1->getOperand(1)) && + "Unexpected operand for identity shuffle"); + + // This is a shuffle of 2 widening shuffles. We can shuffle the narrow source + // operands directly by adjusting the shuffle mask to account for the narrower + // types: + // shuf (widen X), (widen Y), Mask --> shuf X, Y, Mask' + int NarrowElts = X->getType()->getVectorNumElements(); + int WideElts = Shuffle0->getType()->getVectorNumElements(); + assert(WideElts > NarrowElts && "Unexpected types for identity with padding"); + + Type *I32Ty = IntegerType::getInt32Ty(Shuf.getContext()); + SmallVector<int, 16> Mask = Shuf.getShuffleMask(); + SmallVector<Constant *, 16> NewMask(Mask.size(), UndefValue::get(I32Ty)); + for (int i = 0, e = Mask.size(); i != e; ++i) { + if (Mask[i] == -1) + continue; + + // If this shuffle is choosing an undef element from 1 of the sources, that + // element is undef. + if (Mask[i] < WideElts) { + if (Shuffle0->getMaskValue(Mask[i]) == -1) + continue; + } else { + if (Shuffle1->getMaskValue(Mask[i] - WideElts) == -1) + continue; + } + + // If this shuffle is choosing from the 1st narrow op, the mask element is + // the same. If this shuffle is choosing from the 2nd narrow op, the mask + // element is offset down to adjust for the narrow vector widths. + if (Mask[i] < WideElts) { + assert(Mask[i] < NarrowElts && "Unexpected shuffle mask"); + NewMask[i] = ConstantInt::get(I32Ty, Mask[i]); + } else { + assert(Mask[i] < (WideElts + NarrowElts) && "Unexpected shuffle mask"); + NewMask[i] = ConstantInt::get(I32Ty, Mask[i] - (WideElts - NarrowElts)); + } + } + return new ShuffleVectorInst(X, Y, ConstantVector::get(NewMask)); +} + Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *LHS = SVI.getOperand(0); Value *RHS = SVI.getOperand(1); @@ -1595,36 +1756,12 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI))) return replaceInstUsesWith(SVI, V); - if (Instruction *I = foldSelectShuffle(SVI, Builder, DL)) - return I; - - if (Instruction *I = narrowVectorSelect(SVI, Builder)) - return I; - + // Canonicalize shuffle(x ,x,mask) -> shuffle(x, undef,mask') + // Canonicalize shuffle(undef,x,mask) -> shuffle(x, undef,mask'). 
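Aside: the mask remapping in foldIdentityPaddedShuffles above re-bases elements taken from the second (widened) operand by subtracting (WideElts - NarrowElts), because in the new shuffle of the narrow sources the second operand starts at index NarrowElts. A small standalone check of that arithmetic:

#include <cassert>
#include <vector>

int main() {
  const int NarrowElts = 2, WideElts = 4;
  std::vector<int> Mask = {0, 1, 4, 5}; // X[0], X[1], Y[0], Y[1] in wide space
  std::vector<int> NewMask;
  for (int M : Mask)
    NewMask.push_back(M < WideElts ? M : M - (WideElts - NarrowElts));
  assert(NewMask == (std::vector<int>{0, 1, 2, 3})); // Y now starts at index 2
  return 0;
}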
unsigned VWidth = SVI.getType()->getVectorNumElements(); - APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); - if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { - if (V != &SVI) - return replaceInstUsesWith(SVI, V); - return &SVI; - } - - if (Instruction *I = foldIdentityExtractShuffle(SVI)) - return I; - - // This transform has the potential to lose undef knowledge, so it is - // intentionally placed after SimplifyDemandedVectorElts(). - if (Instruction *I = foldShuffleWithInsert(SVI)) - return I; - + unsigned LHSWidth = LHS->getType()->getVectorNumElements(); SmallVector<int, 16> Mask = SVI.getShuffleMask(); Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); - unsigned LHSWidth = LHS->getType()->getVectorNumElements(); - bool MadeChange = false; - - // Canonicalize shuffle(x ,x,mask) -> shuffle(x, undef,mask') - // Canonicalize shuffle(undef,x,mask) -> shuffle(x, undef,mask'). if (LHS == RHS || isa<UndefValue>(LHS)) { // Remap any references to RHS to use LHS. SmallVector<Constant*, 16> Elts; @@ -1646,11 +1783,36 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { SVI.setOperand(0, SVI.getOperand(1)); SVI.setOperand(1, UndefValue::get(RHS->getType())); SVI.setOperand(2, ConstantVector::get(Elts)); - LHS = SVI.getOperand(0); - RHS = SVI.getOperand(1); - MadeChange = true; + return &SVI; } + if (Instruction *I = canonicalizeInsertSplat(SVI, Builder)) + return I; + + if (Instruction *I = foldSelectShuffle(SVI, Builder, DL)) + return I; + + if (Instruction *I = narrowVectorSelect(SVI, Builder)) + return I; + + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { + if (V != &SVI) + return replaceInstUsesWith(SVI, V); + return &SVI; + } + + if (Instruction *I = foldIdentityExtractShuffle(SVI)) + return I; + + // These transforms have the potential to lose undef knowledge, so they are + // intentionally placed after SimplifyDemandedVectorElts(). + if (Instruction *I = foldShuffleWithInsert(SVI)) + return I; + if (Instruction *I = foldIdentityPaddedShuffles(SVI)) + return I; + if (VWidth == LHSWidth) { // Analyze the shuffle, are the LHS or RHS and identity shuffles? bool isLHSID, isRHSID; @@ -1695,6 +1857,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // +-----------+-----------+-----------+-----------+ // Index range [6,10): ^-----------^ Needs an extra shuffle. // Target type i40: ^--------------^ Won't work, bail. + bool MadeChange = false; if (isShuffleExtractingFromLHS(SVI, Mask)) { Value *V = LHS; unsigned MaskElems = Mask.size(); diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index be7d43bbcf2c..385f4926b845 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1,9 +1,8 @@ //===- InstructionCombining.cpp - Combine multiple instructions -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -47,14 +46,17 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LazyBlockFrequencyInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -221,6 +223,11 @@ static bool MaintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) { return !Overflow; } +static bool hasNoUnsignedWrap(BinaryOperator &I) { + OverflowingBinaryOperator *OBO = dyn_cast<OverflowingBinaryOperator>(&I); + return OBO && OBO->hasNoUnsignedWrap(); +} + /// Conservatively clears subclassOptionalData after a reassociation or /// commutation. We preserve fast-math flags when applicable as they can be /// preserved. @@ -327,14 +334,19 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { I.setOperand(1, V); // Conservatively clear the optional flags, since they may not be // preserved by the reassociation. - if (MaintainNoSignedWrap(I, B, C) && + bool IsNUW = hasNoUnsignedWrap(I) && hasNoUnsignedWrap(*Op0); + bool IsNSW = MaintainNoSignedWrap(I, B, C); + + ClearSubclassDataAfterReassociation(I); + + if (IsNUW) + I.setHasNoUnsignedWrap(true); + + if (IsNSW && (!Op0 || (isa<BinaryOperator>(Op0) && Op0->hasNoSignedWrap()))) { // Note: this is only valid because SimplifyBinOp doesn't look at // the operands to Op0. - I.clearSubclassOptionalData(); I.setHasNoSignedWrap(true); - } else { - ClearSubclassDataAfterReassociation(I); } Changed = true; @@ -419,8 +431,14 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { Op0->getOpcode() == Opcode && Op1->getOpcode() == Opcode && match(Op0, m_OneUse(m_BinOp(m_Value(A), m_Constant(C1)))) && match(Op1, m_OneUse(m_BinOp(m_Value(B), m_Constant(C2))))) { - BinaryOperator *NewBO = BinaryOperator::Create(Opcode, A, B); - if (isa<FPMathOperator>(NewBO)) { + bool IsNUW = hasNoUnsignedWrap(I) && + hasNoUnsignedWrap(*Op0) && + hasNoUnsignedWrap(*Op1); + BinaryOperator *NewBO = (IsNUW && Opcode == Instruction::Add) ? + BinaryOperator::CreateNUW(Opcode, A, B) : + BinaryOperator::Create(Opcode, A, B); + + if (isa<FPMathOperator>(NewBO)) { FastMathFlags Flags = I.getFastMathFlags(); Flags &= Op0->getFastMathFlags(); Flags &= Op1->getFastMathFlags(); @@ -433,6 +451,8 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { // Conservatively clear the optional flags, since they may not be // preserved by the reassociation. ClearSubclassDataAfterReassociation(I); + if (IsNUW) + I.setHasNoUnsignedWrap(true); Changed = true; continue; @@ -570,32 +590,44 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I, ++NumFactor; SimplifiedInst->takeName(&I); - // Check if we can add NSW flag to SimplifiedInst. If so, set NSW flag. - // TODO: Check for NUW. + // Check if we can add NSW/NUW flags to SimplifiedInst. If so, set them. 
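Aside: the reassociation hunks above keep the nuw flag when both of the original operations carried it. For unsigned addition the reasoning is that if (A + B) and ((A + B) + C) both stay in range, the full sum fits in the type, so the reassociated adds cannot wrap either. A brute-force check with a pretend 8-bit limit (the helper name is made up):

#include <cassert>

static bool addNoWrap(unsigned A, unsigned B, unsigned Limit) {
  return B <= Limit - A; // A + B <= Limit, checked without overflowing
}

int main() {
  const unsigned Limit = 255; // pretend the IR type is i8
  for (unsigned A = 0; A <= Limit; ++A)
    for (unsigned B = 0; B <= Limit; ++B)
      for (unsigned C = 0; C <= Limit; ++C) {
        if (!addNoWrap(A, B, Limit) || !addNoWrap(A + B, C, Limit))
          continue; // the original form wrapped; no flag to preserve
        assert(addNoWrap(B, C, Limit) && addNoWrap(A, B + C, Limit));
      }
  return 0;
}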
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(SimplifiedInst)) { if (isa<OverflowingBinaryOperator>(SimplifiedInst)) { bool HasNSW = false; - if (isa<OverflowingBinaryOperator>(&I)) + bool HasNUW = false; + if (isa<OverflowingBinaryOperator>(&I)) { HasNSW = I.hasNoSignedWrap(); + HasNUW = I.hasNoUnsignedWrap(); + } - if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) + if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) { HasNSW &= LOBO->hasNoSignedWrap(); + HasNUW &= LOBO->hasNoUnsignedWrap(); + } - if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) + if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) { HasNSW &= ROBO->hasNoSignedWrap(); + HasNUW &= ROBO->hasNoUnsignedWrap(); + } - // We can propagate 'nsw' if we know that - // %Y = mul nsw i16 %X, C - // %Z = add nsw i16 %Y, %X - // => - // %Z = mul nsw i16 %X, C+1 - // - // iff C+1 isn't INT_MIN const APInt *CInt; if (TopLevelOpcode == Instruction::Add && - InnerOpcode == Instruction::Mul) - if (match(V, m_APInt(CInt)) && !CInt->isMinSignedValue()) - BO->setHasNoSignedWrap(HasNSW); + InnerOpcode == Instruction::Mul) { + // We can propagate 'nsw' if we know that + // %Y = mul nsw i16 %X, C + // %Z = add nsw i16 %Y, %X + // => + // %Z = mul nsw i16 %X, C+1 + // + // iff C+1 isn't INT_MIN + if (match(V, m_APInt(CInt))) { + if (!CInt->isMinSignedValue()) + BO->setHasNoSignedWrap(HasNSW); + } + + // nuw can be propagated with any constant or nuw value. + BO->setHasNoUnsignedWrap(HasNUW); + } } } } @@ -922,8 +954,8 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { // If the InVal is an invoke at the end of the pred block, then we can't // insert a computation after it without breaking the edge. - if (InvokeInst *II = dyn_cast<InvokeInst>(InVal)) - if (II->getParent() == NonConstBB) + if (isa<InvokeInst>(InVal)) + if (cast<Instruction>(InVal)->getParent() == NonConstBB) return nullptr; // If the incoming non-constant value is in I's block, we will remove one @@ -1376,7 +1408,8 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { if (match(LHS, m_ShuffleVector(m_Value(L0), m_Value(L1), m_Constant(Mask))) && match(RHS, m_ShuffleVector(m_Value(R0), m_Value(R1), m_Specific(Mask))) && LHS->hasOneUse() && RHS->hasOneUse() && - cast<ShuffleVectorInst>(LHS)->isConcat()) { + cast<ShuffleVectorInst>(LHS)->isConcat() && + cast<ShuffleVectorInst>(RHS)->isConcat()) { // This transform does not have the speculative execution constraint as // below because the shuffle is a concatenation. The new binops are // operating on exactly the same elements as the existing binop. @@ -1415,6 +1448,30 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { return createBinOpShuffle(V1, V2, Mask); } + // If both arguments of a commutative binop are select-shuffles that use the + // same mask with commuted operands, the shuffles are unnecessary. + if (Inst.isCommutative() && + match(LHS, m_ShuffleVector(m_Value(V1), m_Value(V2), m_Constant(Mask))) && + match(RHS, m_ShuffleVector(m_Specific(V2), m_Specific(V1), + m_Specific(Mask)))) { + auto *LShuf = cast<ShuffleVectorInst>(LHS); + auto *RShuf = cast<ShuffleVectorInst>(RHS); + // TODO: Allow shuffles that contain undefs in the mask? + // That is legal, but it reduces undef knowledge. + // TODO: Allow arbitrary shuffles by shuffling after binop? + // That might be legal, but we have to deal with poison. 
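Aside: the commutative-binop fold being added here drops a pair of select-shuffles whose masks match but whose sources are swapped, since every lane then computes either V1[i] op V2[i] or V2[i] op V1[i]. A scalar model for a 4-wide add (illustrative only):

#include <array>
#include <cassert>

int main() {
  std::array<int, 4> V1 = {1, 2, 3, 4};
  std::array<int, 4> V2 = {10, 20, 30, 40};
  std::array<bool, 4> PickFirst = {true, false, false, true}; // select mask

  for (int I = 0; I != 4; ++I) {
    int LHS = PickFirst[I] ? V1[I] : V2[I]; // shuffle V1, V2, Mask
    int RHS = PickFirst[I] ? V2[I] : V1[I]; // shuffle V2, V1, Mask
    assert(LHS + RHS == V1[I] + V2[I]);     // same as the unshuffled binop
  }
  return 0;
}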
+ if (LShuf->isSelect() && !LShuf->getMask()->containsUndefElement() && + RShuf->isSelect() && !RShuf->getMask()->containsUndefElement()) { + // Example: + // LHS = shuffle V1, V2, <0, 5, 6, 3> + // RHS = shuffle V2, V1, <0, 5, 6, 3> + // LHS + RHS --> (V10+V20, V21+V11, V22+V12, V13+V23) --> V1 + V2 + Instruction *NewBO = BinaryOperator::Create(Opcode, V1, V2); + NewBO->copyIRFlags(&Inst); + return NewBO; + } + } + // If one argument is a shuffle within one vector and the other is a constant, // try moving the shuffle after the binary operation. This canonicalization // intends to move shuffles closer to other shuffles and binops closer to @@ -1557,6 +1614,23 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (Value *V = SimplifyGEPInst(GEPEltType, Ops, SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); + // For vector geps, use the generic demanded vector support. + if (GEP.getType()->isVectorTy()) { + auto VWidth = GEP.getType()->getVectorNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(&GEP, AllOnesEltMask, + UndefElts)) { + if (V != &GEP) + return replaceInstUsesWith(GEP, V); + return &GEP; + } + + // TODO: 1) Scalarize splat operands, 2) scalarize entire instruction if + // possible (decide on canonical form for pointer broadcast), 3) exploit + // undef elements to decrease demanded bits + } + Value *PtrOp = GEP.getOperand(0); // Eliminate unneeded casts for indices, and replace indices which displace @@ -1755,9 +1829,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // put NewSrc at same location as %src Builder.SetInsertPoint(cast<Instruction>(PtrOp)); auto *NewSrc = cast<GetElementPtrInst>( - Builder.CreateGEP(SO0, GO1, Src->getName())); + Builder.CreateGEP(GEPEltType, SO0, GO1, Src->getName())); NewSrc->setIsInBounds(Src->isInBounds()); - auto *NewGEP = GetElementPtrInst::Create(nullptr, NewSrc, {SO1}); + auto *NewGEP = GetElementPtrInst::Create(GEPEltType, NewSrc, {SO1}); NewGEP->setIsInBounds(GEP.isInBounds()); return NewGEP; } @@ -1881,6 +1955,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (StrippedPtr != PtrOp) { bool HasZeroPointerIndex = false; + Type *StrippedPtrEltTy = StrippedPtrTy->getElementType(); + if (auto *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) HasZeroPointerIndex = C->isZero(); @@ -1894,11 +1970,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (HasZeroPointerIndex) { if (auto *CATy = dyn_cast<ArrayType>(GEPEltType)) { // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ? - if (CATy->getElementType() == StrippedPtrTy->getElementType()) { + if (CATy->getElementType() == StrippedPtrEltTy) { // -> GEP i8* X, ... SmallVector<Value*, 8> Idx(GEP.idx_begin()+1, GEP.idx_end()); GetElementPtrInst *Res = GetElementPtrInst::Create( - StrippedPtrTy->getElementType(), StrippedPtr, Idx, GEP.getName()); + StrippedPtrEltTy, StrippedPtr, Idx, GEP.getName()); Res->setIsInBounds(GEP.isInBounds()); if (StrippedPtrTy->getAddressSpace() == GEP.getAddressSpace()) return Res; @@ -1911,7 +1987,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { return new AddrSpaceCastInst(Builder.Insert(Res), GEPType); } - if (auto *XATy = dyn_cast<ArrayType>(StrippedPtrTy->getElementType())) { + if (auto *XATy = dyn_cast<ArrayType>(StrippedPtrEltTy)) { // GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... ? 
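Aside: the GEP rewrites in this and the following hunks re-type byte-based GEPs over bitcast pointers when the byte index is a multiple of the element allocation size. The address identity they depend on, expressed in plain C++ pointer arithmetic:

#include <cassert>
#include <cstdint>

int main() {
  int32_t Arr[16] = {};
  const int64_t N = 5;
  // getelementptr i8, i8* (bitcast i32* %arr to i8*), i64 (4 * N)
  char *ByteGEP = reinterpret_cast<char *>(Arr) + sizeof(Arr[0]) * N;
  // getelementptr i32, i32* %arr, i64 N
  char *EltGEP = reinterpret_cast<char *>(Arr + N);
  assert(ByteGEP == EltGEP && "both forms compute the same address");
  return 0;
}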
if (CATy->getElementType() == XATy->getElementType()) { // -> GEP [10 x i8]* X, i32 0, ... @@ -1934,11 +2010,12 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // %0 = GEP [10 x i8] addrspace(1)* X, ... // addrspacecast i8 addrspace(1)* %0 to i8* SmallVector<Value*, 8> Idx(GEP.idx_begin(), GEP.idx_end()); - Value *NewGEP = GEP.isInBounds() - ? Builder.CreateInBoundsGEP( - nullptr, StrippedPtr, Idx, GEP.getName()) - : Builder.CreateGEP(nullptr, StrippedPtr, Idx, - GEP.getName()); + Value *NewGEP = + GEP.isInBounds() + ? Builder.CreateInBoundsGEP(StrippedPtrEltTy, StrippedPtr, + Idx, GEP.getName()) + : Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, + GEP.getName()); return new AddrSpaceCastInst(NewGEP, GEPType); } } @@ -1947,17 +2024,17 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Transform things like: // %t = getelementptr i32* bitcast ([2 x i32]* %str to i32*), i32 %V // into: %t1 = getelementptr [2 x i32]* %str, i32 0, i32 %V; bitcast - Type *SrcEltTy = StrippedPtrTy->getElementType(); - if (SrcEltTy->isArrayTy() && - DL.getTypeAllocSize(SrcEltTy->getArrayElementType()) == + if (StrippedPtrEltTy->isArrayTy() && + DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) == DL.getTypeAllocSize(GEPEltType)) { Type *IdxType = DL.getIndexType(GEPType); Value *Idx[2] = { Constant::getNullValue(IdxType), GEP.getOperand(1) }; Value *NewGEP = GEP.isInBounds() - ? Builder.CreateInBoundsGEP(nullptr, StrippedPtr, Idx, + ? Builder.CreateInBoundsGEP(StrippedPtrEltTy, StrippedPtr, Idx, GEP.getName()) - : Builder.CreateGEP(nullptr, StrippedPtr, Idx, GEP.getName()); + : Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, + GEP.getName()); // V and GEP are both pointer types --> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEPType); @@ -1967,11 +2044,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // %V = mul i64 %N, 4 // %t = getelementptr i8* bitcast (i32* %arr to i8*), i32 %V // into: %t1 = getelementptr i32* %arr, i32 %N; bitcast - if (GEPEltType->isSized() && SrcEltTy->isSized()) { + if (GEPEltType->isSized() && StrippedPtrEltTy->isSized()) { // Check that changing the type amounts to dividing the index by a scale // factor. uint64_t ResSize = DL.getTypeAllocSize(GEPEltType); - uint64_t SrcSize = DL.getTypeAllocSize(SrcEltTy); + uint64_t SrcSize = DL.getTypeAllocSize(StrippedPtrEltTy); if (ResSize && SrcSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); @@ -1990,9 +2067,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // GEP may not be "inbounds". Value *NewGEP = GEP.isInBounds() && NSW - ? Builder.CreateInBoundsGEP(nullptr, StrippedPtr, NewIdx, - GEP.getName()) - : Builder.CreateGEP(nullptr, StrippedPtr, NewIdx, + ? 
Builder.CreateInBoundsGEP(StrippedPtrEltTy, StrippedPtr, + NewIdx, GEP.getName()) + : Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, NewIdx, GEP.getName()); // The NewGEP must be pointer typed, so must the old one -> BitCast @@ -2006,13 +2083,13 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // getelementptr i8* bitcast ([100 x double]* X to i8*), i32 %tmp // (where tmp = 8*tmp2) into: // getelementptr [100 x double]* %arr, i32 0, i32 %tmp2; bitcast - if (GEPEltType->isSized() && SrcEltTy->isSized() && - SrcEltTy->isArrayTy()) { + if (GEPEltType->isSized() && StrippedPtrEltTy->isSized() && + StrippedPtrEltTy->isArrayTy()) { // Check that changing to the array element type amounts to dividing the // index by a scale factor. uint64_t ResSize = DL.getTypeAllocSize(GEPEltType); uint64_t ArrayEltSize = - DL.getTypeAllocSize(SrcEltTy->getArrayElementType()); + DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()); if (ResSize && ArrayEltSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); @@ -2032,11 +2109,12 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { Type *IndTy = DL.getIndexType(GEPType); Value *Off[2] = {Constant::getNullValue(IndTy), NewIdx}; - Value *NewGEP = GEP.isInBounds() && NSW - ? Builder.CreateInBoundsGEP( - SrcEltTy, StrippedPtr, Off, GEP.getName()) - : Builder.CreateGEP(SrcEltTy, StrippedPtr, Off, - GEP.getName()); + Value *NewGEP = + GEP.isInBounds() && NSW + ? Builder.CreateInBoundsGEP(StrippedPtrEltTy, StrippedPtr, + Off, GEP.getName()) + : Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Off, + GEP.getName()); // The NewGEP must be pointer typed, so must the old one -> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEPType); @@ -2084,8 +2162,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // constructing an AddrSpaceCastInst Value *NGEP = GEP.isInBounds() - ? Builder.CreateInBoundsGEP(nullptr, SrcOp, {Ops[1], Ops[2]}) - : Builder.CreateGEP(nullptr, SrcOp, {Ops[1], Ops[2]}); + ? Builder.CreateInBoundsGEP(SrcEltType, SrcOp, {Ops[1], Ops[2]}) + : Builder.CreateGEP(SrcEltType, SrcOp, {Ops[1], Ops[2]}); NGEP->takeName(&GEP); // Preserve GEP address space to satisfy users @@ -2132,8 +2210,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (FindElementAtOffset(SrcType, Offset.getSExtValue(), NewIndices)) { Value *NGEP = GEP.isInBounds() - ? Builder.CreateInBoundsGEP(nullptr, SrcOp, NewIndices) - : Builder.CreateGEP(nullptr, SrcOp, NewIndices); + ? 
Builder.CreateInBoundsGEP(SrcEltType, SrcOp, NewIndices) + : Builder.CreateGEP(SrcEltType, SrcOp, NewIndices); if (NGEP->getType() == GEPType) return replaceInstUsesWith(GEP, NGEP); @@ -2159,7 +2237,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { APInt AllocSize(IdxWidth, DL.getTypeAllocSize(AI->getAllocatedType())); if (BasePtrOffset.ule(AllocSize)) { return GetElementPtrInst::CreateInBounds( - PtrOp, makeArrayRef(Ops).slice(1), GEP.getName()); + GEP.getSourceElementType(), PtrOp, makeArrayRef(Ops).slice(1), + GEP.getName()); } } } @@ -2296,8 +2375,8 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { if (II->getIntrinsicID() == Intrinsic::objectsize) { - ConstantInt *Result = lowerObjectSizeCall(II, DL, &TLI, - /*MustSucceed=*/true); + Value *Result = + lowerObjectSizeCall(II, DL, &TLI, /*MustSucceed=*/true); replaceInstUsesWith(*I, Result); eraseInstFromFunction(*I); Users[i] = nullptr; // Skip examining in the next loop. @@ -2426,9 +2505,8 @@ Instruction *InstCombiner::visitFree(CallInst &FI) { // free undef -> unreachable. if (isa<UndefValue>(Op)) { - // Insert a new store to null because we cannot modify the CFG here. - Builder.CreateStore(ConstantInt::getTrue(FI.getContext()), - UndefValue::get(Type::getInt1PtrTy(FI.getContext()))); + // Leave a marker since we can't modify the CFG here. + CreateNonTerminatorUnreachable(&FI); return eraseInstFromFunction(FI); } @@ -2618,53 +2696,28 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { return ExtractValueInst::Create(IV->getInsertedValueOperand(), makeArrayRef(exti, exte)); } - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Agg)) { - // We're extracting from an intrinsic, see if we're the only user, which - // allows us to simplify multiple result intrinsics to simpler things that - // just get one value. - if (II->hasOneUse()) { - // Check if we're grabbing the overflow bit or the result of a 'with - // overflow' intrinsic. If it's the latter we can remove the intrinsic + if (WithOverflowInst *WO = dyn_cast<WithOverflowInst>(Agg)) { + // We're extracting from an overflow intrinsic, see if we're the only user, + // which allows us to simplify multiple result intrinsics to simpler + // things that just get one value. + if (WO->hasOneUse()) { + // Check if we're grabbing only the result of a 'with overflow' intrinsic // and replace it with a traditional binary instruction. - switch (II->getIntrinsicID()) { - case Intrinsic::uadd_with_overflow: - case Intrinsic::sadd_with_overflow: - if (*EV.idx_begin() == 0) { // Normal result. - Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - replaceInstUsesWith(*II, UndefValue::get(II->getType())); - eraseInstFromFunction(*II); - return BinaryOperator::CreateAdd(LHS, RHS); - } - - // If the normal result of the add is dead, and the RHS is a constant, - // we can transform this into a range comparison. - // overflow = uadd a, -4 --> overflow = icmp ugt a, 3 - if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow) - if (ConstantInt *CI = dyn_cast<ConstantInt>(II->getArgOperand(1))) - return new ICmpInst(ICmpInst::ICMP_UGT, II->getArgOperand(0), - ConstantExpr::getNot(CI)); - break; - case Intrinsic::usub_with_overflow: - case Intrinsic::ssub_with_overflow: - if (*EV.idx_begin() == 0) { // Normal result. 
- Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - replaceInstUsesWith(*II, UndefValue::get(II->getType())); - eraseInstFromFunction(*II); - return BinaryOperator::CreateSub(LHS, RHS); - } - break; - case Intrinsic::umul_with_overflow: - case Intrinsic::smul_with_overflow: - if (*EV.idx_begin() == 0) { // Normal result. - Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1); - replaceInstUsesWith(*II, UndefValue::get(II->getType())); - eraseInstFromFunction(*II); - return BinaryOperator::CreateMul(LHS, RHS); - } - break; - default: - break; + if (*EV.idx_begin() == 0) { + Instruction::BinaryOps BinOp = WO->getBinaryOp(); + Value *LHS = WO->getLHS(), *RHS = WO->getRHS(); + replaceInstUsesWith(*WO, UndefValue::get(WO->getType())); + eraseInstFromFunction(*WO); + return BinaryOperator::Create(BinOp, LHS, RHS); } + + // If the normal result of the add is dead, and the RHS is a constant, + // we can transform this into a range comparison. + // overflow = uadd a, -4 --> overflow = icmp ugt a, 3 + if (WO->getIntrinsicID() == Intrinsic::uadd_with_overflow) + if (ConstantInt *CI = dyn_cast<ConstantInt>(WO->getRHS())) + return new ICmpInst(ICmpInst::ICMP_UGT, WO->getLHS(), + ConstantExpr::getNot(CI)); } } if (LoadInst *L = dyn_cast<LoadInst>(Agg)) @@ -2687,7 +2740,7 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { Builder.SetInsertPoint(L); Value *GEP = Builder.CreateInBoundsGEP(L->getType(), L->getPointerOperand(), Indices); - Instruction *NL = Builder.CreateLoad(GEP); + Instruction *NL = Builder.CreateLoad(EV.getType(), GEP); // Whatever aliasing information we had for the orignal load must also // hold for the smaller load, so propagate the annotations. AAMDNodes Nodes; @@ -3065,9 +3118,11 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { I->isTerminator()) return false; - // Do not sink alloca instructions out of the entry block. - if (isa<AllocaInst>(I) && I->getParent() == - &DestBlock->getParent()->getEntryBlock()) + // Do not sink static or dynamic alloca instructions. Static allocas must + // remain in the entry block, and dynamic allocas must not be sunk in between + // a stacksave / stackrestore pair, which would incorrectly shorten its + // lifetime. + if (isa<AllocaInst>(I)) return false; // Do not sink into catchswitch blocks. @@ -3093,13 +3148,35 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { ++NumSunkInst; // Also sink all related debug uses from the source basic block. Otherwise we - // get debug use before the def. - SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; + // get debug use before the def. Attempt to salvage debug uses first, to + // maximise the range variables have location for. If we cannot salvage, then + // mark the location undef: we know it was supposed to receive a new location + // here, but that computation has been sunk. + SmallVector<DbgVariableIntrinsic *, 2> DbgUsers; findDbgUsers(DbgUsers, I); - for (auto *DII : DbgUsers) { + for (auto *DII : reverse(DbgUsers)) { if (DII->getParent() == SrcBlock) { - DII->moveBefore(&*InsertPos); - LLVM_DEBUG(dbgs() << "SINK: " << *DII << '\n'); + // dbg.value is in the same basic block as the sunk inst, see if we can + // salvage it. Clone a new copy of the instruction: on success we need + // both salvaged and unsalvaged copies. 
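Aside: the extractvalue hunk above keeps the fold "overflow(uadd a, C) --> icmp ugt a, ~C". At 8 bits, a + c wraps exactly when a > 0xFF - c, and ~c equals 0xFF - c. A standalone exhaustive check:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A <= 0xFF; ++A)
    for (unsigned C = 0; C <= 0xFF; ++C) {
      bool Overflow = (A + C) > 0xFF;
      uint8_t NotC = static_cast<uint8_t>(~C); // low 8 bits of ~C, i.e. 255 - C
      bool Icmp = A > NotC;
      assert(Overflow == Icmp);
    }
  return 0;
}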
+ SmallVector<DbgVariableIntrinsic *, 1> TmpUser{ + cast<DbgVariableIntrinsic>(DII->clone())}; + + if (!salvageDebugInfoForDbgValues(*I, TmpUser)) { + // We are unable to salvage: sink the cloned dbg.value, and mark the + // original as undef, terminating any earlier variable location. + LLVM_DEBUG(dbgs() << "SINK: " << *DII << '\n'); + TmpUser[0]->insertBefore(&*InsertPos); + Value *Undef = UndefValue::get(I->getType()); + DII->setOperand(0, MetadataAsValue::get(DII->getContext(), + ValueAsMetadata::get(Undef))); + } else { + // We successfully salvaged: place the salvaged dbg.value in the + // original location, and move the unmodified dbg.value to sink with + // the sunk inst. + TmpUser[0]->insertBefore(DII); + DII->moveBefore(&*InsertPos); + } } } return true; @@ -3294,7 +3371,8 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, if (isInstructionTriviallyDead(Inst, TLI)) { ++NumDeadInst; LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); - salvageDebugInfo(*Inst); + if (!salvageDebugInfo(*Inst)) + replaceDbgUsesWithUndef(Inst); Inst->eraseFromParent(); MadeIRChange = true; continue; @@ -3407,7 +3485,8 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, static bool combineInstructionsOverFunction( Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, - OptimizationRemarkEmitter &ORE, bool ExpensiveCombines = true, + OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, + ProfileSummaryInfo *PSI, bool ExpensiveCombines = true, LoopInfo *LI = nullptr) { auto &DL = F.getParent()->getDataLayout(); ExpensiveCombines |= EnableExpensiveCombines; @@ -3437,8 +3516,8 @@ static bool combineInstructionsOverFunction( MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist); - InstCombiner IC(Worklist, Builder, F.optForMinSize(), ExpensiveCombines, AA, - AC, TLI, DT, ORE, DL, LI); + InstCombiner IC(Worklist, Builder, F.hasMinSize(), ExpensiveCombines, AA, + AC, TLI, DT, ORE, BFI, PSI, DL, LI); IC.MaxArraySizeForCombine = MaxArraySize; if (!IC.run()) @@ -3458,8 +3537,15 @@ PreservedAnalyses InstCombinePass::run(Function &F, auto *LI = AM.getCachedResult<LoopAnalysis>(F); auto *AA = &AM.getResult<AAManager>(F); + const ModuleAnalysisManager &MAM = + AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); + ProfileSummaryInfo *PSI = + MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + auto *BFI = (PSI && PSI->hasProfileSummary()) ? + &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; + if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, - ExpensiveCombines, LI)) + BFI, PSI, ExpensiveCombines, LI)) // No changes, all analyses are preserved. return PreservedAnalyses::all(); @@ -3483,6 +3569,8 @@ void InstructionCombiningPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addPreserved<AAResultsWrapperPass>(); AU.addPreserved<BasicAAWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); + LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU); } bool InstructionCombiningPass::runOnFunction(Function &F) { @@ -3499,9 +3587,15 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { // Optional analyses. auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; + ProfileSummaryInfo *PSI = + &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + BlockFrequencyInfo *BFI = + (PSI && PSI->hasProfileSummary()) ? 
+ &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() : + nullptr; return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, - ExpensiveCombines, LI); + BFI, PSI, ExpensiveCombines, LI); } char InstructionCombiningPass::ID = 0; @@ -3514,6 +3608,8 @@ INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) INITIALIZE_PASS_END(InstructionCombiningPass, "instcombine", "Combine redundant instructions", false, false) diff --git a/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/lib/Transforms/Instrumentation/AddressSanitizer.cpp index f1558c75cb90..6821e214e921 100644 --- a/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -1,9 +1,8 @@ //===- AddressSanitizer.cpp - memory error detector -----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -13,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Instrumentation/AddressSanitizer.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" @@ -25,7 +25,6 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/BinaryFormat/MachO.h" #include "llvm/IR/Argument.h" @@ -72,6 +71,7 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/ASanStackFrameLayout.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> @@ -94,9 +94,6 @@ static const uint64_t kDefaultShadowOffset32 = 1ULL << 29; static const uint64_t kDefaultShadowOffset64 = 1ULL << 44; static const uint64_t kDynamicShadowSentinel = std::numeric_limits<uint64_t>::max(); -static const uint64_t kIOSShadowOffset32 = 1ULL << 30; -static const uint64_t kIOSSimShadowOffset32 = 1ULL << 30; -static const uint64_t kIOSSimShadowOffset64 = kDefaultShadowOffset64; static const uint64_t kSmallX86_64ShadowOffsetBase = 0x7FFFFFFF; // < 2G. 
static const uint64_t kSmallX86_64ShadowOffsetAlignMask = ~0xFFFULL; static const uint64_t kLinuxKasan_ShadowOffset64 = 0xdffffc0000000000; @@ -112,6 +109,7 @@ static const uint64_t kNetBSD_ShadowOffset64 = 1ULL << 46; static const uint64_t kNetBSDKasan_ShadowOffset64 = 0xdfff900000000000; static const uint64_t kPS4CPU_ShadowOffset64 = 1ULL << 40; static const uint64_t kWindowsShadowOffset32 = 3ULL << 28; +static const uint64_t kEmscriptenShadowOffset = 0; static const uint64_t kMyriadShadowScale = 5; static const uint64_t kMyriadMemoryOffset32 = 0x80000000ULL; @@ -275,6 +273,16 @@ static cl::opt<bool> ClInvalidPointerPairs( cl::desc("Instrument <, <=, >, >=, - with pointer operands"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClInvalidPointerCmp( + "asan-detect-invalid-pointer-cmp", + cl::desc("Instrument <, <=, >, >= with pointer operands"), cl::Hidden, + cl::init(false)); + +static cl::opt<bool> ClInvalidPointerSub( + "asan-detect-invalid-pointer-sub", + cl::desc("Instrument - operations with pointer operands"), cl::Hidden, + cl::init(false)); + static cl::opt<unsigned> ClRealignStack( "asan-realign-stack", cl::desc("Realign stack to the value of this flag (power of two)"), @@ -311,10 +319,10 @@ static cl::opt<int> ClMappingScale("asan-mapping-scale", cl::desc("scale of asan shadow mapping"), cl::Hidden, cl::init(0)); -static cl::opt<unsigned long long> ClMappingOffset( - "asan-mapping-offset", - cl::desc("offset of asan shadow mapping [EXPERIMENTAL]"), cl::Hidden, - cl::init(0)); +static cl::opt<uint64_t> + ClMappingOffset("asan-mapping-offset", + cl::desc("offset of asan shadow mapping [EXPERIMENTAL]"), + cl::Hidden, cl::init(0)); // Optimization flags. Not user visible, used mostly for testing // and benchmarking the tool. @@ -393,87 +401,6 @@ STATISTIC(NumOptimizedAccessesToStackVar, namespace { -/// Frontend-provided metadata for source location. -struct LocationMetadata { - StringRef Filename; - int LineNo = 0; - int ColumnNo = 0; - - LocationMetadata() = default; - - bool empty() const { return Filename.empty(); } - - void parse(MDNode *MDN) { - assert(MDN->getNumOperands() == 3); - MDString *DIFilename = cast<MDString>(MDN->getOperand(0)); - Filename = DIFilename->getString(); - LineNo = - mdconst::extract<ConstantInt>(MDN->getOperand(1))->getLimitedValue(); - ColumnNo = - mdconst::extract<ConstantInt>(MDN->getOperand(2))->getLimitedValue(); - } -}; - -/// Frontend-provided metadata for global variables. -class GlobalsMetadata { -public: - struct Entry { - LocationMetadata SourceLoc; - StringRef Name; - bool IsDynInit = false; - bool IsBlacklisted = false; - - Entry() = default; - }; - - GlobalsMetadata() = default; - - void reset() { - inited_ = false; - Entries.clear(); - } - - void init(Module &M) { - assert(!inited_); - inited_ = true; - NamedMDNode *Globals = M.getNamedMetadata("llvm.asan.globals"); - if (!Globals) return; - for (auto MDN : Globals->operands()) { - // Metadata node contains the global and the fields of "Entry". - assert(MDN->getNumOperands() == 5); - auto *V = mdconst::extract_or_null<Constant>(MDN->getOperand(0)); - // The optimizer may optimize away a global entirely. - if (!V) continue; - auto *StrippedV = V->stripPointerCasts(); - auto *GV = dyn_cast<GlobalVariable>(StrippedV); - if (!GV) continue; - // We can already have an entry for GV if it was merged with another - // global. 
- Entry &E = Entries[GV]; - if (auto *Loc = cast_or_null<MDNode>(MDN->getOperand(1))) - E.SourceLoc.parse(Loc); - if (auto *Name = cast_or_null<MDString>(MDN->getOperand(2))) - E.Name = Name->getString(); - ConstantInt *IsDynInit = - mdconst::extract<ConstantInt>(MDN->getOperand(3)); - E.IsDynInit |= IsDynInit->isOne(); - ConstantInt *IsBlacklisted = - mdconst::extract<ConstantInt>(MDN->getOperand(4)); - E.IsBlacklisted |= IsBlacklisted->isOne(); - } - } - - /// Returns metadata entry for a given global. - Entry get(GlobalVariable *G) const { - auto Pos = Entries.find(G); - return (Pos != Entries.end()) ? Pos->second : Entry(); - } - -private: - bool inited_ = false; - DenseMap<GlobalVariable *, Entry> Entries; -}; - /// This struct defines the shadow mapping using the rule: /// shadow = (mem >> Scale) ADD-or-OR Offset. /// If InGlobal is true, then @@ -499,7 +426,6 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, bool IsPPC64 = TargetTriple.getArch() == Triple::ppc64 || TargetTriple.getArch() == Triple::ppc64le; bool IsSystemZ = TargetTriple.getArch() == Triple::systemz; - bool IsX86 = TargetTriple.getArch() == Triple::x86; bool IsX86_64 = TargetTriple.getArch() == Triple::x86_64; bool IsMIPS32 = TargetTriple.isMIPS32(); bool IsMIPS64 = TargetTriple.isMIPS64(); @@ -508,6 +434,7 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, bool IsWindows = TargetTriple.isOSWindows(); bool IsFuchsia = TargetTriple.isOSFuchsia(); bool IsMyriad = TargetTriple.getVendor() == llvm::Triple::Myriad; + bool IsEmscripten = TargetTriple.isOSEmscripten(); ShadowMapping Mapping; @@ -526,10 +453,11 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, else if (IsNetBSD) Mapping.Offset = kNetBSD_ShadowOffset32; else if (IsIOS) - // If we're targeting iOS and x86, the binary is built for iOS simulator. - Mapping.Offset = IsX86 ? kIOSSimShadowOffset32 : kIOSShadowOffset32; + Mapping.Offset = kDynamicShadowSentinel; else if (IsWindows) Mapping.Offset = kWindowsShadowOffset32; + else if (IsEmscripten) + Mapping.Offset = kEmscriptenShadowOffset; else if (IsMyriad) { uint64_t ShadowOffset = (kMyriadMemoryOffset32 + kMyriadMemorySize32 - (kMyriadMemorySize32 >> Mapping.Scale)); @@ -566,10 +494,7 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, } else if (IsMIPS64) Mapping.Offset = kMIPS64_ShadowOffset64; else if (IsIOS) - // If we're targeting iOS and x86, the binary is built for iOS simulator. - // We are using dynamic shadow offset on the 64-bit devices. - Mapping.Offset = - IsX86_64 ? kIOSSimShadowOffset64 : kDynamicShadowSentinel; + Mapping.Offset = kDynamicShadowSentinel; else if (IsAArch64) Mapping.Offset = kAArch64_ShadowOffset64; else @@ -607,27 +532,53 @@ static size_t RedzoneSizeForScale(int MappingScale) { namespace { -/// AddressSanitizer: instrument the code in module to find memory bugs. -struct AddressSanitizer : public FunctionPass { - // Pass identification, replacement for typeid +/// Module analysis for getting various metadata about the module. +class ASanGlobalsMetadataWrapperPass : public ModulePass { +public: static char ID; - explicit AddressSanitizer(bool CompileKernel = false, bool Recover = false, - bool UseAfterScope = false) - : FunctionPass(ID), UseAfterScope(UseAfterScope || ClUseAfterScope) { - this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover; - this->CompileKernel = ClEnableKasan.getNumOccurrences() > 0 ? 
- ClEnableKasan : CompileKernel; - initializeAddressSanitizerPass(*PassRegistry::getPassRegistry()); + ASanGlobalsMetadataWrapperPass() : ModulePass(ID) { + initializeASanGlobalsMetadataWrapperPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { + GlobalsMD = GlobalsMetadata(M); + return false; } StringRef getPassName() const override { - return "AddressSanitizerFunctionPass"; + return "ASanGlobalsMetadataWrapperPass"; } void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.setPreservesAll(); + } + + GlobalsMetadata &getGlobalsMD() { return GlobalsMD; } + +private: + GlobalsMetadata GlobalsMD; +}; + +char ASanGlobalsMetadataWrapperPass::ID = 0; + +/// AddressSanitizer: instrument the code in module to find memory bugs. +struct AddressSanitizer { + AddressSanitizer(Module &M, GlobalsMetadata &GlobalsMD, + bool CompileKernel = false, bool Recover = false, + bool UseAfterScope = false) + : UseAfterScope(UseAfterScope || ClUseAfterScope), GlobalsMD(GlobalsMD) { + this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover; + this->CompileKernel = + ClEnableKasan.getNumOccurrences() > 0 ? ClEnableKasan : CompileKernel; + + C = &(M.getContext()); + LongSize = M.getDataLayout().getPointerSizeInBits(); + IntptrTy = Type::getIntNTy(*C, LongSize); + TargetTriple = Triple(M.getTargetTriple()); + + Mapping = getShadowMapping(TargetTriple, LongSize, this->CompileKernel); } uint64_t getAllocaSizeInBytes(const AllocaInst &AI) const { @@ -672,14 +623,10 @@ struct AddressSanitizer : public FunctionPass { Value *SizeArgument, uint32_t Exp); void instrumentMemIntrinsic(MemIntrinsic *MI); Value *memToShadow(Value *Shadow, IRBuilder<> &IRB); - bool runOnFunction(Function &F) override; + bool instrumentFunction(Function &F, const TargetLibraryInfo *TLI); bool maybeInsertAsanInitAtFunctionEntry(Function &F); void maybeInsertDynamicShadowAtFunctionEntry(Function &F); void markEscapedLocalAllocas(Function &F); - bool doInitialization(Module &M) override; - bool doFinalization(Module &M) override; - - DominatorTree &getDominatorTree() const { return *DT; } private: friend struct FunctionStackPoisoner; @@ -715,36 +662,68 @@ private: bool UseAfterScope; Type *IntptrTy; ShadowMapping Mapping; - DominatorTree *DT; - Function *AsanHandleNoReturnFunc; - Function *AsanPtrCmpFunction, *AsanPtrSubFunction; + FunctionCallee AsanHandleNoReturnFunc; + FunctionCallee AsanPtrCmpFunction, AsanPtrSubFunction; Constant *AsanShadowGlobal; // These arrays is indexed by AccessIsWrite, Experiment and log2(AccessSize). - Function *AsanErrorCallback[2][2][kNumberOfAccessSizes]; - Function *AsanMemoryAccessCallback[2][2][kNumberOfAccessSizes]; + FunctionCallee AsanErrorCallback[2][2][kNumberOfAccessSizes]; + FunctionCallee AsanMemoryAccessCallback[2][2][kNumberOfAccessSizes]; // These arrays is indexed by AccessIsWrite and Experiment. 
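The restructuring above extracts the instrumentation logic from the legacy FunctionPass into a plain AddressSanitizer class whose constructor does the per-module setup that used to live in doInitialization, so the same object can be driven by either pass manager. A stripped-down sketch of that shape, using placeholder names rather than the patch's own classes:

    #include "llvm/IR/Function.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Pass.h"
    using namespace llvm;

    namespace {
    // Plain transformation object: all per-module state is captured in the
    // constructor, so there is no doInitialization/doFinalization to keep in
    // sync between pass managers.
    class ExampleInstrumenter {
    public:
      explicit ExampleInstrumenter(Module &M)
          : C(&M.getContext()),
            LongSize(M.getDataLayout().getPointerSizeInBits()) {}

      bool instrumentFunction(Function &F) {
        (void)F; // real work would go here
        return false;
      }

    private:
      LLVMContext *C;
      unsigned LongSize;
    };

    // Thin legacy-PM adapter: gather analyses, build the plain object,
    // delegate. (INITIALIZE_PASS registration omitted for brevity.)
    class ExampleLegacyPass : public FunctionPass {
    public:
      static char ID;
      ExampleLegacyPass() : FunctionPass(ID) {}
      bool runOnFunction(Function &F) override {
        ExampleInstrumenter Instr(*F.getParent());
        return Instr.instrumentFunction(F);
      }
    };
    char ExampleLegacyPass::ID = 0;
    } // end anonymous namespace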
- Function *AsanErrorCallbackSized[2][2]; - Function *AsanMemoryAccessCallbackSized[2][2]; + FunctionCallee AsanErrorCallbackSized[2][2]; + FunctionCallee AsanMemoryAccessCallbackSized[2][2]; - Function *AsanMemmove, *AsanMemcpy, *AsanMemset; + FunctionCallee AsanMemmove, AsanMemcpy, AsanMemset; InlineAsm *EmptyAsm; Value *LocalDynamicShadow = nullptr; GlobalsMetadata GlobalsMD; DenseMap<const AllocaInst *, bool> ProcessedAllocas; }; -class AddressSanitizerModule : public ModulePass { +class AddressSanitizerLegacyPass : public FunctionPass { public: - // Pass identification, replacement for typeid static char ID; - explicit AddressSanitizerModule(bool CompileKernel = false, - bool Recover = false, - bool UseGlobalsGC = true, - bool UseOdrIndicator = false) - : ModulePass(ID), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC), + explicit AddressSanitizerLegacyPass(bool CompileKernel = false, + bool Recover = false, + bool UseAfterScope = false) + : FunctionPass(ID), CompileKernel(CompileKernel), Recover(Recover), + UseAfterScope(UseAfterScope) { + initializeAddressSanitizerLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { + return "AddressSanitizerFunctionPass"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<ASanGlobalsMetadataWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + } + + bool runOnFunction(Function &F) override { + GlobalsMetadata &GlobalsMD = + getAnalysis<ASanGlobalsMetadataWrapperPass>().getGlobalsMD(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + AddressSanitizer ASan(*F.getParent(), GlobalsMD, CompileKernel, Recover, + UseAfterScope); + return ASan.instrumentFunction(F, TLI); + } + +private: + bool CompileKernel; + bool Recover; + bool UseAfterScope; +}; + +class ModuleAddressSanitizer { +public: + ModuleAddressSanitizer(Module &M, GlobalsMetadata &GlobalsMD, + bool CompileKernel = false, bool Recover = false, + bool UseGlobalsGC = true, bool UseOdrIndicator = false) + : GlobalsMD(GlobalsMD), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC), // Enable aliases as they should have no downside with ODR indicators. UsePrivateAlias(UseOdrIndicator || ClUsePrivateAlias), UseOdrIndicator(UseOdrIndicator || ClUseOdrIndicator), @@ -759,10 +738,15 @@ public: this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover; this->CompileKernel = ClEnableKasan.getNumOccurrences() > 0 ? 
ClEnableKasan : CompileKernel; + + C = &(M.getContext()); + int LongSize = M.getDataLayout().getPointerSizeInBits(); + IntptrTy = Type::getIntNTy(*C, LongSize); + TargetTriple = Triple(M.getTargetTriple()); + Mapping = getShadowMapping(TargetTriple, LongSize, this->CompileKernel); } - bool runOnModule(Module &M) override; - StringRef getPassName() const override { return "AddressSanitizerModule"; } + bool instrumentModule(Module &); private: void initializeCallbacks(Module &M); @@ -810,19 +794,54 @@ private: LLVMContext *C; Triple TargetTriple; ShadowMapping Mapping; - Function *AsanPoisonGlobals; - Function *AsanUnpoisonGlobals; - Function *AsanRegisterGlobals; - Function *AsanUnregisterGlobals; - Function *AsanRegisterImageGlobals; - Function *AsanUnregisterImageGlobals; - Function *AsanRegisterElfGlobals; - Function *AsanUnregisterElfGlobals; + FunctionCallee AsanPoisonGlobals; + FunctionCallee AsanUnpoisonGlobals; + FunctionCallee AsanRegisterGlobals; + FunctionCallee AsanUnregisterGlobals; + FunctionCallee AsanRegisterImageGlobals; + FunctionCallee AsanUnregisterImageGlobals; + FunctionCallee AsanRegisterElfGlobals; + FunctionCallee AsanUnregisterElfGlobals; Function *AsanCtorFunction = nullptr; Function *AsanDtorFunction = nullptr; }; +class ModuleAddressSanitizerLegacyPass : public ModulePass { +public: + static char ID; + + explicit ModuleAddressSanitizerLegacyPass(bool CompileKernel = false, + bool Recover = false, + bool UseGlobalGC = true, + bool UseOdrIndicator = false) + : ModulePass(ID), CompileKernel(CompileKernel), Recover(Recover), + UseGlobalGC(UseGlobalGC), UseOdrIndicator(UseOdrIndicator) { + initializeModuleAddressSanitizerLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { return "ModuleAddressSanitizer"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<ASanGlobalsMetadataWrapperPass>(); + } + + bool runOnModule(Module &M) override { + GlobalsMetadata &GlobalsMD = + getAnalysis<ASanGlobalsMetadataWrapperPass>().getGlobalsMD(); + ModuleAddressSanitizer ASanModule(M, GlobalsMD, CompileKernel, Recover, + UseGlobalGC, UseOdrIndicator); + return ASanModule.instrumentModule(M); + } + +private: + bool CompileKernel; + bool Recover; + bool UseGlobalGC; + bool UseOdrIndicator; +}; + // Stack poisoning does not play well with exception handling. // When an exception is thrown, we essentially bypass the code // that unpoisones the stack. This is why the run-time library has @@ -846,11 +865,11 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { SmallVector<Instruction *, 8> RetVec; unsigned StackAlignment; - Function *AsanStackMallocFunc[kMaxAsanStackMallocSizeClass + 1], - *AsanStackFreeFunc[kMaxAsanStackMallocSizeClass + 1]; - Function *AsanSetShadowFunc[0x100] = {}; - Function *AsanPoisonStackMemoryFunc, *AsanUnpoisonStackMemoryFunc; - Function *AsanAllocaPoisonFunc, *AsanAllocasUnpoisonFunc; + FunctionCallee AsanStackMallocFunc[kMaxAsanStackMallocSizeClass + 1], + AsanStackFreeFunc[kMaxAsanStackMallocSizeClass + 1]; + FunctionCallee AsanSetShadowFunc[0x100] = {}; + FunctionCallee AsanPoisonStackMemoryFunc, AsanUnpoisonStackMemoryFunc; + FunctionCallee AsanAllocaPoisonFunc, AsanAllocasUnpoisonFunc; // Stores a place and arguments of poisoning/unpoisoning call for alloca. 
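Throughout these hunks the runtime-hook members switch from Function * to FunctionCallee, which carries the callee together with its FunctionType; Module::getOrInsertFunction now returns that type directly, which is why the checkSanitizerInterfaceFunction wrappers disappear later in the file. A minimal sketch of the new pattern, with an invented hook name:

    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Module.h"
    using namespace llvm;

    // Declare (or look up) a void(i8*) runtime hook and emit a call to it.
    static void emitExampleHook(Module &M, IRBuilder<> &IRB, Value *Ptr) {
      // getOrInsertFunction returns a FunctionCallee holding both the callee
      // and its FunctionType, even when the existing declaration required a
      // bitcast.
      FunctionCallee Hook = M.getOrInsertFunction(
          "__example_hook", IRB.getVoidTy(), IRB.getInt8PtrTy());
      // CreateCall takes the FunctionCallee as-is; no cast back to Function*.
      IRB.CreateCall(Hook, {IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy())});
    }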
struct AllocaPoisonCall { @@ -861,6 +880,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { }; SmallVector<AllocaPoisonCall, 8> DynamicAllocaPoisonCallVec; SmallVector<AllocaPoisonCall, 8> StaticAllocaPoisonCallVec; + bool HasUntracedLifetimeIntrinsic = false; SmallVector<AllocaInst *, 1> DynamicAllocaVec; SmallVector<IntrinsicInst *, 1> StackRestoreVec; @@ -876,13 +896,9 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { std::unique_ptr<CallInst> EmptyInlineAsm; FunctionStackPoisoner(Function &F, AddressSanitizer &ASan) - : F(F), - ASan(ASan), - DIB(*F.getParent(), /*AllowUnresolved*/ false), - C(ASan.C), - IntptrTy(ASan.IntptrTy), - IntptrPtrTy(PointerType::get(IntptrTy, 0)), - Mapping(ASan.Mapping), + : F(F), ASan(ASan), DIB(*F.getParent(), /*AllowUnresolved*/ false), + C(ASan.C), IntptrTy(ASan.IntptrTy), + IntptrPtrTy(PointerType::get(IntptrTy, 0)), Mapping(ASan.Mapping), StackAlignment(1 << Mapping.Scale), EmptyInlineAsm(CallInst::Create(ASan.EmptyAsm)) {} @@ -899,6 +915,14 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { initializeCallbacks(*F.getParent()); + if (HasUntracedLifetimeIntrinsic) { + // If there are lifetime intrinsics which couldn't be traced back to an + // alloca, we may not know exactly when a variable enters scope, and + // therefore should "fail safe" by not poisoning them. + StaticAllocaPoisonCallVec.clear(); + DynamicAllocaPoisonCallVec.clear(); + } + processDynamicAllocas(); processStaticAllocas(); @@ -950,8 +974,9 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { DynamicAreaOffset); } - IRB.CreateCall(AsanAllocasUnpoisonFunc, - {IRB.CreateLoad(DynamicAllocaLayout), DynamicAreaPtr}); + IRB.CreateCall( + AsanAllocasUnpoisonFunc, + {IRB.CreateLoad(IntptrTy, DynamicAllocaLayout), DynamicAreaPtr}); } // Unpoison dynamic allocas redzones. @@ -1018,8 +1043,14 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { !ConstantInt::isValueValidForType(IntptrTy, SizeValue)) return; // Find alloca instruction that corresponds to llvm.lifetime argument. - AllocaInst *AI = findAllocaForValue(II.getArgOperand(1)); - if (!AI || !ASan.isInterestingAlloca(*AI)) + AllocaInst *AI = + llvm::findAllocaForValue(II.getArgOperand(1), AllocaForValue); + if (!AI) { + HasUntracedLifetimeIntrinsic = true; + return; + } + // We're interested only in allocas we can handle. + if (!ASan.isInterestingAlloca(*AI)) return; bool DoPoison = (ID == Intrinsic::lifetime_end); AllocaPoisonCall APC = {&II, AI, SizeValue, DoPoison}; @@ -1042,16 +1073,6 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { // ---------------------- Helpers. void initializeCallbacks(Module &M); - bool doesDominateAllExits(const Instruction *I) const { - for (auto Ret : RetVec) { - if (!ASan.getDominatorTree().dominates(I, Ret)) return false; - } - return true; - } - - /// Finds alloca where the value comes from. - AllocaInst *findAllocaForValue(Value *V); - // Copies bytes from ShadowBytes into shadow memory for indexes where // ShadowMask is not zero. If ShadowMask[i] is zero, we assume that // ShadowBytes[i] is constantly zero and doesn't need to be overwritten. 
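The CreateLoad calls in this hunk (and throughout the rest of the patch) now name the loaded type explicitly instead of deriving it from the pointer operand's pointee type, part of the preparation for opaque pointers. A small sketch of the new form; the helper and variable names are illustrative only:

    #include "llvm/IR/IRBuilder.h"
    using namespace llvm;

    // Old, deprecated form: the result type was inferred from Ptr's pointee
    // type, e.g. IRB.CreateLoad(Ptr).
    // New form: the loaded type is stated up front, so the IR being built no
    // longer depends on the pointer's element type.
    static Value *loadIntptr(IRBuilder<> &IRB, Type *IntptrTy, Value *Ptr) {
      return IRB.CreateLoad(IntptrTy, Ptr, "loaded.intptr");
    }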
@@ -1074,16 +1095,111 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { } // end anonymous namespace -char AddressSanitizer::ID = 0; +void LocationMetadata::parse(MDNode *MDN) { + assert(MDN->getNumOperands() == 3); + MDString *DIFilename = cast<MDString>(MDN->getOperand(0)); + Filename = DIFilename->getString(); + LineNo = mdconst::extract<ConstantInt>(MDN->getOperand(1))->getLimitedValue(); + ColumnNo = + mdconst::extract<ConstantInt>(MDN->getOperand(2))->getLimitedValue(); +} + +// FIXME: It would be cleaner to instead attach relevant metadata to the globals +// we want to sanitize instead and reading this metadata on each pass over a +// function instead of reading module level metadata at first. +GlobalsMetadata::GlobalsMetadata(Module &M) { + NamedMDNode *Globals = M.getNamedMetadata("llvm.asan.globals"); + if (!Globals) + return; + for (auto MDN : Globals->operands()) { + // Metadata node contains the global and the fields of "Entry". + assert(MDN->getNumOperands() == 5); + auto *V = mdconst::extract_or_null<Constant>(MDN->getOperand(0)); + // The optimizer may optimize away a global entirely. + if (!V) + continue; + auto *StrippedV = V->stripPointerCasts(); + auto *GV = dyn_cast<GlobalVariable>(StrippedV); + if (!GV) + continue; + // We can already have an entry for GV if it was merged with another + // global. + Entry &E = Entries[GV]; + if (auto *Loc = cast_or_null<MDNode>(MDN->getOperand(1))) + E.SourceLoc.parse(Loc); + if (auto *Name = cast_or_null<MDString>(MDN->getOperand(2))) + E.Name = Name->getString(); + ConstantInt *IsDynInit = mdconst::extract<ConstantInt>(MDN->getOperand(3)); + E.IsDynInit |= IsDynInit->isOne(); + ConstantInt *IsBlacklisted = + mdconst::extract<ConstantInt>(MDN->getOperand(4)); + E.IsBlacklisted |= IsBlacklisted->isOne(); + } +} + +AnalysisKey ASanGlobalsMetadataAnalysis::Key; + +GlobalsMetadata ASanGlobalsMetadataAnalysis::run(Module &M, + ModuleAnalysisManager &AM) { + return GlobalsMetadata(M); +} + +AddressSanitizerPass::AddressSanitizerPass(bool CompileKernel, bool Recover, + bool UseAfterScope) + : CompileKernel(CompileKernel), Recover(Recover), + UseAfterScope(UseAfterScope) {} + +PreservedAnalyses AddressSanitizerPass::run(Function &F, + AnalysisManager<Function> &AM) { + auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); + auto &MAM = MAMProxy.getManager(); + Module &M = *F.getParent(); + if (auto *R = MAM.getCachedResult<ASanGlobalsMetadataAnalysis>(M)) { + const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F); + AddressSanitizer Sanitizer(M, *R, CompileKernel, Recover, UseAfterScope); + if (Sanitizer.instrumentFunction(F, TLI)) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); + } + + report_fatal_error( + "The ASanGlobalsMetadataAnalysis is required to run before " + "AddressSanitizer can run"); + return PreservedAnalyses::all(); +} + +ModuleAddressSanitizerPass::ModuleAddressSanitizerPass(bool CompileKernel, + bool Recover, + bool UseGlobalGC, + bool UseOdrIndicator) + : CompileKernel(CompileKernel), Recover(Recover), UseGlobalGC(UseGlobalGC), + UseOdrIndicator(UseOdrIndicator) {} + +PreservedAnalyses ModuleAddressSanitizerPass::run(Module &M, + AnalysisManager<Module> &AM) { + GlobalsMetadata &GlobalsMD = AM.getResult<ASanGlobalsMetadataAnalysis>(M); + ModuleAddressSanitizer Sanitizer(M, GlobalsMD, CompileKernel, Recover, + UseGlobalGC, UseOdrIndicator); + if (Sanitizer.instrumentModule(M)) + return PreservedAnalyses::none(); + return 
PreservedAnalyses::all(); +} + +INITIALIZE_PASS(ASanGlobalsMetadataWrapperPass, "asan-globals-md", + "Read metadata to mark which globals should be instrumented " + "when running ASan.", + false, true) + +char AddressSanitizerLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN( - AddressSanitizer, "asan", + AddressSanitizerLegacyPass, "asan", "AddressSanitizer: detects use-after-free and out-of-bounds bugs.", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ASanGlobalsMetadataWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END( - AddressSanitizer, "asan", + AddressSanitizerLegacyPass, "asan", "AddressSanitizer: detects use-after-free and out-of-bounds bugs.", false, false) @@ -1091,24 +1207,22 @@ FunctionPass *llvm::createAddressSanitizerFunctionPass(bool CompileKernel, bool Recover, bool UseAfterScope) { assert(!CompileKernel || Recover); - return new AddressSanitizer(CompileKernel, Recover, UseAfterScope); + return new AddressSanitizerLegacyPass(CompileKernel, Recover, UseAfterScope); } -char AddressSanitizerModule::ID = 0; +char ModuleAddressSanitizerLegacyPass::ID = 0; INITIALIZE_PASS( - AddressSanitizerModule, "asan-module", + ModuleAddressSanitizerLegacyPass, "asan-module", "AddressSanitizer: detects use-after-free and out-of-bounds bugs." "ModulePass", false, false) -ModulePass *llvm::createAddressSanitizerModulePass(bool CompileKernel, - bool Recover, - bool UseGlobalsGC, - bool UseOdrIndicator) { +ModulePass *llvm::createModuleAddressSanitizerLegacyPassPass( + bool CompileKernel, bool Recover, bool UseGlobalsGC, bool UseOdrIndicator) { assert(!CompileKernel || Recover); - return new AddressSanitizerModule(CompileKernel, Recover, UseGlobalsGC, - UseOdrIndicator); + return new ModuleAddressSanitizerLegacyPass(CompileKernel, Recover, + UseGlobalsGC, UseOdrIndicator); } static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { @@ -1312,11 +1426,24 @@ static bool isPointerOperand(Value *V) { // This is a rough heuristic; it may cause both false positives and // false negatives. The proper implementation requires cooperation with // the frontend. -static bool isInterestingPointerComparisonOrSubtraction(Instruction *I) { +static bool isInterestingPointerComparison(Instruction *I) { if (ICmpInst *Cmp = dyn_cast<ICmpInst>(I)) { - if (!Cmp->isRelational()) return false; - } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) { - if (BO->getOpcode() != Instruction::Sub) return false; + if (!Cmp->isRelational()) + return false; + } else { + return false; + } + return isPointerOperand(I->getOperand(0)) && + isPointerOperand(I->getOperand(1)); +} + +// This is a rough heuristic; it may cause both false positives and +// false negatives. The proper implementation requires cooperation with +// the frontend. +static bool isInterestingPointerSubtraction(Instruction *I) { + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) { + if (BO->getOpcode() != Instruction::Sub) + return false; } else { return false; } @@ -1328,13 +1455,16 @@ bool AddressSanitizer::GlobalIsLinkerInitialized(GlobalVariable *G) { // If a global variable does not have dynamic initialization we don't // have to instrument it. However, if a global does not have initializer // at all, we assume it has dynamic initializer (in other TU). + // + // FIXME: Metadata should be attched directly to the global directly instead + // of being added to llvm.asan.globals. 
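The new pass-manager entry points above reach module-level data from a function pass by going through ModuleAnalysisManagerFunctionProxy and accepting only a cached result. A condensed sketch of that access pattern; FunctionCountAnalysis and ExamplePass are placeholders, not classes from this patch.

    #include "llvm/IR/Function.h"
    #include "llvm/IR/Module.h"
    #include "llvm/IR/PassManager.h"
    using namespace llvm;

    // Placeholder module analysis whose result is just the function count.
    struct FunctionCountAnalysis : AnalysisInfoMixin<FunctionCountAnalysis> {
      using Result = unsigned;
      Result run(Module &M, ModuleAnalysisManager &) {
        return static_cast<unsigned>(M.size());
      }
      static AnalysisKey Key;
    };
    AnalysisKey FunctionCountAnalysis::Key;

    // A function pass that only reads the module-level result, and only if
    // some earlier module pass already cached it, mirroring the structure
    // used by the sanitizer pass above.
    struct ExamplePass : PassInfoMixin<ExamplePass> {
      PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
        auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
        ModuleAnalysisManager &MAM = MAMProxy.getManager();
        Module &M = *F.getParent();
        if (unsigned *Count = MAM.getCachedResult<FunctionCountAnalysis>(M)) {
          (void)Count; // drive the per-function work from the cached data
          return PreservedAnalyses::none();
        }
        return PreservedAnalyses::all(); // nothing cached: conservatively bail
      }
    };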
return G->hasInitializer() && !GlobalsMD.get(G).IsDynInit; } void AddressSanitizer::instrumentPointerComparisonOrSubtraction( Instruction *I) { IRBuilder<> IRB(I); - Function *F = isa<ICmpInst>(I) ? AsanPtrCmpFunction : AsanPtrSubFunction; + FunctionCallee F = isa<ICmpInst>(I) ? AsanPtrCmpFunction : AsanPtrSubFunction; Value *Param[2] = {I->getOperand(0), I->getOperand(1)}; for (Value *&i : Param) { if (i->getType()->isPointerTy()) @@ -1392,7 +1522,7 @@ static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass, IRBuilder<> IRB(InsertBefore); InstrumentedAddress = - IRB.CreateGEP(Addr, {Zero, ConstantInt::get(IntptrTy, Idx)}); + IRB.CreateGEP(VTy, Addr, {Zero, ConstantInt::get(IntptrTy, Idx)}); doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment, Granularity, ElemTypeSize, IsWrite, SizeArgument, UseCalls, Exp); @@ -1553,7 +1683,7 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, Value *ShadowPtr = memToShadow(AddrLong, IRB); Value *CmpVal = Constant::getNullValue(ShadowTy); Value *ShadowValue = - IRB.CreateLoad(IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy)); + IRB.CreateLoad(ShadowTy, IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy)); Value *Cmp = IRB.CreateICmpNE(ShadowValue, CmpVal); size_t Granularity = 1ULL << Mapping.Scale; @@ -1612,7 +1742,7 @@ void AddressSanitizer::instrumentUnusualSizeOrAlignment( } } -void AddressSanitizerModule::poisonOneInitializer(Function &GlobalInit, +void ModuleAddressSanitizer::poisonOneInitializer(Function &GlobalInit, GlobalValue *ModuleName) { // Set up the arguments to our poison/unpoison functions. IRBuilder<> IRB(&GlobalInit.front(), @@ -1628,7 +1758,7 @@ void AddressSanitizerModule::poisonOneInitializer(Function &GlobalInit, CallInst::Create(AsanUnpoisonGlobals, "", RI); } -void AddressSanitizerModule::createInitializerPoisonCalls( +void ModuleAddressSanitizer::createInitializerPoisonCalls( Module &M, GlobalValue *ModuleName) { GlobalVariable *GV = M.getGlobalVariable("llvm.global_ctors"); if (!GV) @@ -1653,10 +1783,12 @@ void AddressSanitizerModule::createInitializerPoisonCalls( } } -bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { +bool ModuleAddressSanitizer::ShouldInstrumentGlobal(GlobalVariable *G) { Type *Ty = G->getValueType(); LLVM_DEBUG(dbgs() << "GLOBAL: " << *G << "\n"); + // FIXME: Metadata should be attched directly to the global directly instead + // of being added to llvm.asan.globals. if (GlobalsMD.get(G).IsBlacklisted) return false; if (!Ty->isSized()) return false; if (!G->hasInitializer()) return false; @@ -1768,7 +1900,7 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { // On Mach-O platforms, we emit global metadata in a separate section of the // binary in order to allow the linker to properly dead strip. This is only // supported on recent versions of ld64. 
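For context, instrumentAddress (whose shadow load above now names ShadowTy explicitly) emits the usual ASan check: shift the address right by the mapping scale, add the platform offset chosen in getShadowMapping, load one shadow byte, and branch to the report path if it is non-zero. A simplified sketch of just the address arithmetic; the partial-granule fast path and the OR-based mappings are omitted, and the helper name is invented:

    #include "llvm/IR/IRBuilder.h"
    using namespace llvm;

    // shadow = (addr >> Scale) + Offset, then load the shadow byte.
    static Value *loadShadowByte(IRBuilder<> &IRB, Value *AddrLong,
                                 Type *IntptrTy, unsigned Scale,
                                 uint64_t Offset) {
      Value *Shadow = IRB.CreateLShr(AddrLong, Scale);
      Shadow = IRB.CreateAdd(Shadow, ConstantInt::get(IntptrTy, Offset));
      Type *ShadowTy = IRB.getInt8Ty();
      Value *ShadowPtr = IRB.CreateIntToPtr(Shadow, ShadowTy->getPointerTo());
      return IRB.CreateLoad(ShadowTy, ShadowPtr);
    }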
-bool AddressSanitizerModule::ShouldUseMachOGlobalsSection() const { +bool ModuleAddressSanitizer::ShouldUseMachOGlobalsSection() const { if (!TargetTriple.isOSBinFormatMachO()) return false; @@ -1782,7 +1914,7 @@ bool AddressSanitizerModule::ShouldUseMachOGlobalsSection() const { return false; } -StringRef AddressSanitizerModule::getGlobalMetadataSection() const { +StringRef ModuleAddressSanitizer::getGlobalMetadataSection() const { switch (TargetTriple.getObjectFormat()) { case Triple::COFF: return ".ASAN$GL"; case Triple::ELF: return "asan_globals"; @@ -1792,52 +1924,39 @@ StringRef AddressSanitizerModule::getGlobalMetadataSection() const { llvm_unreachable("unsupported object format"); } -void AddressSanitizerModule::initializeCallbacks(Module &M) { +void ModuleAddressSanitizer::initializeCallbacks(Module &M) { IRBuilder<> IRB(*C); // Declare our poisoning and unpoisoning functions. - AsanPoisonGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy)); - AsanPoisonGlobals->setLinkage(Function::ExternalLinkage); - AsanUnpoisonGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanUnpoisonGlobalsName, IRB.getVoidTy())); - AsanUnpoisonGlobals->setLinkage(Function::ExternalLinkage); + AsanPoisonGlobals = + M.getOrInsertFunction(kAsanPoisonGlobalsName, IRB.getVoidTy(), IntptrTy); + AsanUnpoisonGlobals = + M.getOrInsertFunction(kAsanUnpoisonGlobalsName, IRB.getVoidTy()); // Declare functions that register/unregister globals. - AsanRegisterGlobals = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanRegisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy)); - AsanRegisterGlobals->setLinkage(Function::ExternalLinkage); - AsanUnregisterGlobals = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kAsanUnregisterGlobalsName, IRB.getVoidTy(), - IntptrTy, IntptrTy)); - AsanUnregisterGlobals->setLinkage(Function::ExternalLinkage); + AsanRegisterGlobals = M.getOrInsertFunction( + kAsanRegisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy); + AsanUnregisterGlobals = M.getOrInsertFunction( + kAsanUnregisterGlobalsName, IRB.getVoidTy(), IntptrTy, IntptrTy); // Declare the functions that find globals in a shared object and then invoke // the (un)register function on them. 
- AsanRegisterImageGlobals = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanRegisterImageGlobalsName, IRB.getVoidTy(), IntptrTy)); - AsanRegisterImageGlobals->setLinkage(Function::ExternalLinkage); - - AsanUnregisterImageGlobals = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanUnregisterImageGlobalsName, IRB.getVoidTy(), IntptrTy)); - AsanUnregisterImageGlobals->setLinkage(Function::ExternalLinkage); + AsanRegisterImageGlobals = M.getOrInsertFunction( + kAsanRegisterImageGlobalsName, IRB.getVoidTy(), IntptrTy); + AsanUnregisterImageGlobals = M.getOrInsertFunction( + kAsanUnregisterImageGlobalsName, IRB.getVoidTy(), IntptrTy); - AsanRegisterElfGlobals = checkSanitizerInterfaceFunction( + AsanRegisterElfGlobals = M.getOrInsertFunction(kAsanRegisterElfGlobalsName, IRB.getVoidTy(), - IntptrTy, IntptrTy, IntptrTy)); - AsanRegisterElfGlobals->setLinkage(Function::ExternalLinkage); - - AsanUnregisterElfGlobals = checkSanitizerInterfaceFunction( + IntptrTy, IntptrTy, IntptrTy); + AsanUnregisterElfGlobals = M.getOrInsertFunction(kAsanUnregisterElfGlobalsName, IRB.getVoidTy(), - IntptrTy, IntptrTy, IntptrTy)); - AsanUnregisterElfGlobals->setLinkage(Function::ExternalLinkage); + IntptrTy, IntptrTy, IntptrTy); } // Put the metadata and the instrumented global in the same group. This ensures // that the metadata is discarded if the instrumented global is discarded. -void AddressSanitizerModule::SetComdatForGlobalMetadata( +void ModuleAddressSanitizer::SetComdatForGlobalMetadata( GlobalVariable *G, GlobalVariable *Metadata, StringRef InternalSuffix) { Module &M = *G->getParent(); Comdat *C = G->getComdat(); @@ -1875,7 +1994,7 @@ void AddressSanitizerModule::SetComdatForGlobalMetadata( // Create a separate metadata global and put it in the appropriate ASan // global registration section. GlobalVariable * -AddressSanitizerModule::CreateMetadataGlobal(Module &M, Constant *Initializer, +ModuleAddressSanitizer::CreateMetadataGlobal(Module &M, Constant *Initializer, StringRef OriginalName) { auto Linkage = TargetTriple.isOSBinFormatMachO() ? 
GlobalVariable::InternalLinkage @@ -1887,7 +2006,7 @@ AddressSanitizerModule::CreateMetadataGlobal(Module &M, Constant *Initializer, return Metadata; } -IRBuilder<> AddressSanitizerModule::CreateAsanModuleDtor(Module &M) { +IRBuilder<> ModuleAddressSanitizer::CreateAsanModuleDtor(Module &M) { AsanDtorFunction = Function::Create(FunctionType::get(Type::getVoidTy(*C), false), GlobalValue::InternalLinkage, kAsanModuleDtorName, &M); @@ -1896,7 +2015,7 @@ IRBuilder<> AddressSanitizerModule::CreateAsanModuleDtor(Module &M) { return IRBuilder<>(ReturnInst::Create(*C, AsanDtorBB)); } -void AddressSanitizerModule::InstrumentGlobalsCOFF( +void ModuleAddressSanitizer::InstrumentGlobalsCOFF( IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, ArrayRef<Constant *> MetadataInitializers) { assert(ExtendedGlobals.size() == MetadataInitializers.size()); @@ -1920,7 +2039,7 @@ void AddressSanitizerModule::InstrumentGlobalsCOFF( } } -void AddressSanitizerModule::InstrumentGlobalsELF( +void ModuleAddressSanitizer::InstrumentGlobalsELF( IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, ArrayRef<Constant *> MetadataInitializers, const std::string &UniqueModuleId) { @@ -1979,7 +2098,7 @@ void AddressSanitizerModule::InstrumentGlobalsELF( IRB.CreatePointerCast(StopELFMetadata, IntptrTy)}); } -void AddressSanitizerModule::InstrumentGlobalsMachO( +void ModuleAddressSanitizer::InstrumentGlobalsMachO( IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, ArrayRef<Constant *> MetadataInitializers) { assert(ExtendedGlobals.size() == MetadataInitializers.size()); @@ -2036,7 +2155,7 @@ void AddressSanitizerModule::InstrumentGlobalsMachO( {IRB.CreatePointerCast(RegisteredFlag, IntptrTy)}); } -void AddressSanitizerModule::InstrumentGlobalsWithMetadataArray( +void ModuleAddressSanitizer::InstrumentGlobalsWithMetadataArray( IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, ArrayRef<Constant *> MetadataInitializers) { assert(ExtendedGlobals.size() == MetadataInitializers.size()); @@ -2070,9 +2189,9 @@ void AddressSanitizerModule::InstrumentGlobalsWithMetadataArray( // redzones and inserts this function into llvm.global_ctors. // Sets *CtorComdat to true if the global registration code emitted into the // asan constructor is comdat-compatible. -bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool *CtorComdat) { +bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, + bool *CtorComdat) { *CtorComdat = false; - GlobalsMD.init(M); SmallVector<GlobalVariable *, 16> GlobalsToChange; @@ -2115,6 +2234,8 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool static const uint64_t kMaxGlobalRedzone = 1 << 18; GlobalVariable *G = GlobalsToChange[i]; + // FIXME: Metadata should be attched directly to the global directly instead + // of being added to llvm.asan.globals. 
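CreateAsanModuleDtor, renamed above onto ModuleAddressSanitizer, builds the internal destructor that later unregisters instrumented globals. The general shape of creating such a module destructor and wiring it into llvm.global_dtors is sketched below; the function name and priority are illustrative, not taken from the patch:

    #include "llvm/IR/Function.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Transforms/Utils/ModuleUtils.h"
    using namespace llvm;

    // Create an internal void() destructor with a lone return, then register
    // it so the runtime calls it at image teardown.
    static Function *createExampleModuleDtor(Module &M) {
      LLVMContext &C = M.getContext();
      Function *Dtor = Function::Create(
          FunctionType::get(Type::getVoidTy(C), /*isVarArg=*/false),
          GlobalValue::InternalLinkage, "example.module_dtor", &M);
      BasicBlock *BB = BasicBlock::Create(C, "", Dtor);
      IRBuilder<> IRB(ReturnInst::Create(C, BB));
      // Calls to the unregister callbacks would be emitted through IRB here.
      appendToGlobalDtors(M, Dtor, /*Priority=*/1);
      return Dtor;
    }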
auto MD = GlobalsMD.get(G); StringRef NameForGlobal = G->getName(); // Create string holding the global name (use global name from metadata @@ -2271,7 +2392,7 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool return true; } -int AddressSanitizerModule::GetAsanVersion(const Module &M) const { +int ModuleAddressSanitizer::GetAsanVersion(const Module &M) const { int LongSize = M.getDataLayout().getPointerSizeInBits(); bool isAndroid = Triple(M.getTargetTriple()).isAndroid(); int Version = 8; @@ -2281,12 +2402,7 @@ int AddressSanitizerModule::GetAsanVersion(const Module &M) const { return Version; } -bool AddressSanitizerModule::runOnModule(Module &M) { - C = &(M.getContext()); - int LongSize = M.getDataLayout().getPointerSizeInBits(); - IntptrTy = Type::getIntNTy(*C, LongSize); - TargetTriple = Triple(M.getTargetTriple()); - Mapping = getShadowMapping(TargetTriple, LongSize, CompileKernel); +bool ModuleAddressSanitizer::instrumentModule(Module &M) { initializeCallbacks(M); if (CompileKernel) @@ -2346,51 +2462,49 @@ void AddressSanitizer::initializeCallbacks(Module &M) { Args2.push_back(ExpType); Args1.push_back(ExpType); } - AsanErrorCallbackSized[AccessIsWrite][Exp] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanReportErrorTemplate + ExpStr + TypeStr + "_n" + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args2, false))); + AsanErrorCallbackSized[AccessIsWrite][Exp] = M.getOrInsertFunction( + kAsanReportErrorTemplate + ExpStr + TypeStr + "_n" + EndingStr, + FunctionType::get(IRB.getVoidTy(), Args2, false)); - AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args2, false))); + AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] = M.getOrInsertFunction( + ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr, + FunctionType::get(IRB.getVoidTy(), Args2, false)); for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; AccessSizeIndex++) { const std::string Suffix = TypeStr + itostr(1ULL << AccessSizeIndex); AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( + M.getOrInsertFunction( kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args1, false))); + FunctionType::get(IRB.getVoidTy(), Args1, false)); AsanMemoryAccessCallback[AccessIsWrite][Exp][AccessSizeIndex] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( + M.getOrInsertFunction( ClMemoryAccessCallbackPrefix + ExpStr + Suffix + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args1, false))); + FunctionType::get(IRB.getVoidTy(), Args1, false)); } } } const std::string MemIntrinCallbackPrefix = CompileKernel ? 
std::string("") : ClMemoryAccessCallbackPrefix; - AsanMemmove = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - MemIntrinCallbackPrefix + "memmove", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy)); - AsanMemcpy = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - MemIntrinCallbackPrefix + "memcpy", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy)); - AsanMemset = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - MemIntrinCallbackPrefix + "memset", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy)); - - AsanHandleNoReturnFunc = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy())); - - AsanPtrCmpFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanPtrCmp, IRB.getVoidTy(), IntptrTy, IntptrTy)); - AsanPtrSubFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanPtrSub, IRB.getVoidTy(), IntptrTy, IntptrTy)); + AsanMemmove = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memmove", + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IntptrTy); + AsanMemcpy = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memcpy", + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IntptrTy); + AsanMemset = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memset", + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt32Ty(), IntptrTy); + + AsanHandleNoReturnFunc = + M.getOrInsertFunction(kAsanHandleNoReturnName, IRB.getVoidTy()); + + AsanPtrCmpFunction = + M.getOrInsertFunction(kAsanPtrCmp, IRB.getVoidTy(), IntptrTy, IntptrTy); + AsanPtrSubFunction = + M.getOrInsertFunction(kAsanPtrSub, IRB.getVoidTy(), IntptrTy, IntptrTy); // We insert an empty inline asm after __asan_report* to avoid callback merge. EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), StringRef(""), StringRef(""), @@ -2400,25 +2514,6 @@ void AddressSanitizer::initializeCallbacks(Module &M) { ArrayType::get(IRB.getInt8Ty(), 0)); } -// virtual -bool AddressSanitizer::doInitialization(Module &M) { - // Initialize the private fields. No one has accessed them before. - GlobalsMD.init(M); - - C = &(M.getContext()); - LongSize = M.getDataLayout().getPointerSizeInBits(); - IntptrTy = Type::getIntNTy(*C, LongSize); - TargetTriple = Triple(M.getTargetTriple()); - - Mapping = getShadowMapping(TargetTriple, LongSize, CompileKernel); - return true; -} - -bool AddressSanitizer::doFinalization(Module &M) { - GlobalsMD.reset(); - return false; -} - bool AddressSanitizer::maybeInsertAsanInitAtFunctionEntry(Function &F) { // For each NSObject descendant having a +load method, this method is invoked // by the ObjC runtime before any of the static constructors is called. @@ -2428,7 +2523,7 @@ bool AddressSanitizer::maybeInsertAsanInitAtFunctionEntry(Function &F) { // We cannot just ignore these methods, because they may call other // instrumented functions. 
if (F.getName().find(" load]") != std::string::npos) { - Function *AsanInitFunction = + FunctionCallee AsanInitFunction = declareSanitizerInitFunction(*F.getParent(), kAsanInitName, {}); IRBuilder<> IRB(&F.front(), F.front().begin()); IRB.CreateCall(AsanInitFunction, {}); @@ -2460,7 +2555,7 @@ void AddressSanitizer::maybeInsertDynamicShadowAtFunctionEntry(Function &F) { } else { Value *GlobalDynamicAddress = F.getParent()->getOrInsertGlobal( kAsanShadowMemoryDynamicAddress, IntptrTy); - LocalDynamicShadow = IRB.CreateLoad(GlobalDynamicAddress); + LocalDynamicShadow = IRB.CreateLoad(IntptrTy, GlobalDynamicAddress); } } @@ -2492,7 +2587,8 @@ void AddressSanitizer::markEscapedLocalAllocas(Function &F) { } } -bool AddressSanitizer::runOnFunction(Function &F) { +bool AddressSanitizer::instrumentFunction(Function &F, + const TargetLibraryInfo *TLI) { if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false; if (!ClDebugFunc.empty() && ClDebugFunc == F.getName()) return false; if (F.getName().startswith("__asan_")) return false; @@ -2511,7 +2607,6 @@ bool AddressSanitizer::runOnFunction(Function &F) { LLVM_DEBUG(dbgs() << "ASAN instrumenting:\n" << F << "\n"); initializeCallbacks(*F.getParent()); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); FunctionStateRAII CleanupObj(this); @@ -2532,8 +2627,6 @@ bool AddressSanitizer::runOnFunction(Function &F) { bool IsWrite; unsigned Alignment; uint64_t TypeSize; - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); // Fill the set of memory operations to instrument. for (auto &BB : F) { @@ -2557,8 +2650,10 @@ bool AddressSanitizer::runOnFunction(Function &F) { continue; // We've seen this temp in the current BB. } } - } else if (ClInvalidPointerPairs && - isInterestingPointerComparisonOrSubtraction(&Inst)) { + } else if (((ClInvalidPointerPairs || ClInvalidPointerCmp) && + isInterestingPointerComparison(&Inst)) || + ((ClInvalidPointerPairs || ClInvalidPointerSub) && + isInterestingPointerSubtraction(&Inst))) { PointerComparisonsOrSubtracts.push_back(&Inst); continue; } else if (isa<MemIntrinsic>(Inst)) { @@ -2569,7 +2664,8 @@ bool AddressSanitizer::runOnFunction(Function &F) { if (CS) { // A call inside BB. TempsToInstrument.clear(); - if (CS.doesNotReturn()) NoReturnCalls.push_back(CS.getInstruction()); + if (CS.doesNotReturn() && !CS->getMetadata("nosanitize")) + NoReturnCalls.push_back(CS.getInstruction()); } if (CallInst *CI = dyn_cast<CallInst>(&Inst)) maybeMarkSanitizerLibraryCallNoBuiltin(CI, TLI); @@ -2606,7 +2702,7 @@ bool AddressSanitizer::runOnFunction(Function &F) { FunctionStackPoisoner FSP(F, *this); bool ChangedStack = FSP.runOnFunction(); - // We must unpoison the stack before every NoReturn call (throw, _exit, etc). + // We must unpoison the stack before NoReturn calls (throw, _exit, etc). // See e.g. 
https://github.com/google/sanitizers/issues/37 for (auto CI : NoReturnCalls) { IRBuilder<> IRB(CI); @@ -2643,20 +2739,17 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) { IRBuilder<> IRB(*C); for (int i = 0; i <= kMaxAsanStackMallocSizeClass; i++) { std::string Suffix = itostr(i); - AsanStackMallocFunc[i] = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kAsanStackMallocNameTemplate + Suffix, IntptrTy, - IntptrTy)); - AsanStackFreeFunc[i] = checkSanitizerInterfaceFunction( + AsanStackMallocFunc[i] = M.getOrInsertFunction( + kAsanStackMallocNameTemplate + Suffix, IntptrTy, IntptrTy); + AsanStackFreeFunc[i] = M.getOrInsertFunction(kAsanStackFreeNameTemplate + Suffix, - IRB.getVoidTy(), IntptrTy, IntptrTy)); + IRB.getVoidTy(), IntptrTy, IntptrTy); } if (ASan.UseAfterScope) { - AsanPoisonStackMemoryFunc = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kAsanPoisonStackMemoryName, IRB.getVoidTy(), - IntptrTy, IntptrTy)); - AsanUnpoisonStackMemoryFunc = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kAsanUnpoisonStackMemoryName, IRB.getVoidTy(), - IntptrTy, IntptrTy)); + AsanPoisonStackMemoryFunc = M.getOrInsertFunction( + kAsanPoisonStackMemoryName, IRB.getVoidTy(), IntptrTy, IntptrTy); + AsanUnpoisonStackMemoryFunc = M.getOrInsertFunction( + kAsanUnpoisonStackMemoryName, IRB.getVoidTy(), IntptrTy, IntptrTy); } for (size_t Val : {0x00, 0xf1, 0xf2, 0xf3, 0xf5, 0xf8}) { @@ -2664,15 +2757,13 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) { Name << kAsanSetShadowPrefix; Name << std::setw(2) << std::setfill('0') << std::hex << Val; AsanSetShadowFunc[Val] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - Name.str(), IRB.getVoidTy(), IntptrTy, IntptrTy)); + M.getOrInsertFunction(Name.str(), IRB.getVoidTy(), IntptrTy, IntptrTy); } - AsanAllocaPoisonFunc = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanAllocaPoison, IRB.getVoidTy(), IntptrTy, IntptrTy)); - AsanAllocasUnpoisonFunc = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanAllocasUnpoison, IRB.getVoidTy(), IntptrTy, IntptrTy)); + AsanAllocaPoisonFunc = M.getOrInsertFunction( + kAsanAllocaPoison, IRB.getVoidTy(), IntptrTy, IntptrTy); + AsanAllocasUnpoisonFunc = M.getOrInsertFunction( + kAsanAllocasUnpoison, IRB.getVoidTy(), IntptrTy, IntptrTy); } void FunctionStackPoisoner::copyToShadowInline(ArrayRef<uint8_t> ShadowMask, @@ -2958,7 +3049,7 @@ void FunctionStackPoisoner::processStaticAllocas() { Value *FakeStack; Value *LocalStackBase; Value *LocalStackBaseAlloca; - bool Deref; + uint8_t DIExprFlags = DIExpression::ApplyOffset; if (DoStackMalloc) { LocalStackBaseAlloca = @@ -2969,9 +3060,9 @@ void FunctionStackPoisoner::processStaticAllocas() { // void *LocalStackBase = (FakeStack) ? 
FakeStack : alloca(LocalStackSize); Constant *OptionDetectUseAfterReturn = F.getParent()->getOrInsertGlobal( kAsanOptionDetectUseAfterReturn, IRB.getInt32Ty()); - Value *UseAfterReturnIsEnabled = - IRB.CreateICmpNE(IRB.CreateLoad(OptionDetectUseAfterReturn), - Constant::getNullValue(IRB.getInt32Ty())); + Value *UseAfterReturnIsEnabled = IRB.CreateICmpNE( + IRB.CreateLoad(IRB.getInt32Ty(), OptionDetectUseAfterReturn), + Constant::getNullValue(IRB.getInt32Ty())); Instruction *Term = SplitBlockAndInsertIfThen(UseAfterReturnIsEnabled, InsBefore, false); IRBuilder<> IRBIf(Term); @@ -2999,7 +3090,7 @@ void FunctionStackPoisoner::processStaticAllocas() { LocalStackBase = createPHI(IRB, NoFakeStack, AllocaValue, Term, FakeStack); IRB.SetCurrentDebugLocation(EntryDebugLocation); IRB.CreateStore(LocalStackBase, LocalStackBaseAlloca); - Deref = true; + DIExprFlags |= DIExpression::DerefBefore; } else { // void *FakeStack = nullptr; // void *LocalStackBase = alloca(LocalStackSize); @@ -3007,14 +3098,13 @@ void FunctionStackPoisoner::processStaticAllocas() { LocalStackBase = DoDynamicAlloca ? createAllocaForLayout(IRB, L, true) : StaticAlloca; LocalStackBaseAlloca = LocalStackBase; - Deref = false; } // Replace Alloca instructions with base+offset. for (const auto &Desc : SVD) { AllocaInst *AI = Desc.AI; - replaceDbgDeclareForAlloca(AI, LocalStackBaseAlloca, DIB, Deref, - Desc.Offset, DIExpression::NoDeref); + replaceDbgDeclareForAlloca(AI, LocalStackBaseAlloca, DIB, DIExprFlags, + Desc.Offset); Value *NewAllocaPtr = IRB.CreateIntToPtr( IRB.CreateAdd(LocalStackBase, ConstantInt::get(IntptrTy, Desc.Offset)), AI->getType()); @@ -3105,7 +3195,7 @@ void FunctionStackPoisoner::processStaticAllocas() { FakeStack, ConstantInt::get(IntptrTy, ClassSize - ASan.LongSize / 8)); Value *SavedFlagPtr = IRBPoison.CreateLoad( - IRBPoison.CreateIntToPtr(SavedFlagPtrPtr, IntptrPtrTy)); + IntptrTy, IRBPoison.CreateIntToPtr(SavedFlagPtrPtr, IntptrPtrTy)); IRBPoison.CreateStore( Constant::getNullValue(IRBPoison.getInt8Ty()), IRBPoison.CreateIntToPtr(SavedFlagPtr, IRBPoison.getInt8PtrTy())); @@ -3145,41 +3235,6 @@ void FunctionStackPoisoner::poisonAlloca(Value *V, uint64_t Size, // variable may go in and out of scope several times, e.g. in loops). // (3) if we poisoned at least one %alloca in a function, // unpoison the whole stack frame at function exit. - -AllocaInst *FunctionStackPoisoner::findAllocaForValue(Value *V) { - if (AllocaInst *AI = dyn_cast<AllocaInst>(V)) - // We're interested only in allocas we can handle. - return ASan.isInterestingAlloca(*AI) ? AI : nullptr; - // See if we've already calculated (or started to calculate) alloca for a - // given value. - AllocaForValueMapTy::iterator I = AllocaForValue.find(V); - if (I != AllocaForValue.end()) return I->second; - // Store 0 while we're calculating alloca for value V to avoid - // infinite recursion if the value references itself. - AllocaForValue[V] = nullptr; - AllocaInst *Res = nullptr; - if (CastInst *CI = dyn_cast<CastInst>(V)) - Res = findAllocaForValue(CI->getOperand(0)); - else if (PHINode *PN = dyn_cast<PHINode>(V)) { - for (Value *IncValue : PN->incoming_values()) { - // Allow self-referencing phi-nodes. - if (IncValue == PN) continue; - AllocaInst *IncValueAI = findAllocaForValue(IncValue); - // AI for incoming values should exist and should all be equal. 
- if (IncValueAI == nullptr || (Res != nullptr && IncValueAI != Res)) - return nullptr; - Res = IncValueAI; - } - } else if (GetElementPtrInst *EP = dyn_cast<GetElementPtrInst>(V)) { - Res = findAllocaForValue(EP->getPointerOperand()); - } else { - LLVM_DEBUG(dbgs() << "Alloca search canceled on unknown instruction: " << *V - << "\n"); - } - if (Res) AllocaForValue[V] = Res; - return Res; -} - void FunctionStackPoisoner::handleDynamicAllocaCall(AllocaInst *AI) { IRBuilder<> IRB(AI); diff --git a/lib/Transforms/Instrumentation/BoundsChecking.cpp b/lib/Transforms/Instrumentation/BoundsChecking.cpp index a0c78e0468c6..4dc9b611c156 100644 --- a/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -1,9 +1,8 @@ //===- BoundsChecking.cpp - Instrumentation for run-time bounds checking --===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -143,8 +142,9 @@ static void insertBoundsCheck(Value *Or, BuilderTy IRB, GetTrapBBT GetTrapBB) { static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, ScalarEvolution &SE) { const DataLayout &DL = F.getParent()->getDataLayout(); - ObjectSizeOffsetEvaluator ObjSizeEval(DL, &TLI, F.getContext(), - /*RoundToAlign=*/true); + ObjectSizeOpts EvalOpts; + EvalOpts.RoundToAlign = true; + ObjectSizeOffsetEvaluator ObjSizeEval(DL, &TLI, F.getContext(), EvalOpts); // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory // touching instructions diff --git a/lib/Transforms/Instrumentation/CFGMST.h b/lib/Transforms/Instrumentation/CFGMST.h index e178ef386e68..971e00041762 100644 --- a/lib/Transforms/Instrumentation/CFGMST.h +++ b/lib/Transforms/Instrumentation/CFGMST.h @@ -1,9 +1,8 @@ //===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -196,11 +195,10 @@ public: // Sort CFG edges based on its weight. 
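The hand-rolled findAllocaForValue removed above is superseded by the shared llvm::findAllocaForValue helper used in the lifetime-intrinsic hunk earlier in this file; it walks casts, GEPs and PHIs back to a unique alloca and memoizes the result in a caller-owned map. A sketch of a call site, assuming the Transforms/Utils/Local.h declaration that takes a cache map; the wrapper name is invented:

    #include "llvm/ADT/DenseMap.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/Transforms/Utils/Local.h"
    using namespace llvm;

    // Map a (possibly casted, GEP'd or PHI'd) pointer back to its alloca, or
    // null if the walk is ambiguous. The cache persists across queries so
    // repeated lookups within one function stay cheap.
    static AllocaInst *lookupAlloca(Value *Ptr,
                                    DenseMap<Value *, AllocaInst *> &Cache) {
      return llvm::findAllocaForValue(Ptr, Cache);
    }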
void sortEdgesByWeight() { - std::stable_sort(AllEdges.begin(), AllEdges.end(), - [](const std::unique_ptr<Edge> &Edge1, - const std::unique_ptr<Edge> &Edge2) { - return Edge1->Weight > Edge2->Weight; - }); + llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1, + const std::unique_ptr<Edge> &Edge2) { + return Edge1->Weight > Edge2->Weight; + }); } // Traverse all the edges and compute the Minimum Weight Spanning Tree diff --git a/lib/Transforms/Instrumentation/CGProfile.cpp b/lib/Transforms/Instrumentation/CGProfile.cpp index cdcd01726906..358abab3cceb 100644 --- a/lib/Transforms/Instrumentation/CGProfile.cpp +++ b/lib/Transforms/Instrumentation/CGProfile.cpp @@ -1,9 +1,8 @@ //===-- CGProfile.cpp -----------------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/lib/Transforms/Instrumentation/ControlHeightReduction.cpp index 1ada0b713092..3f4f9bc7145d 100644 --- a/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -1,9 +1,8 @@ //===-- ControlHeightReduction.cpp - Control Height Reduction -------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -547,19 +546,25 @@ static std::set<Value *> getBaseValues(Value *V, static bool checkHoistValue(Value *V, Instruction *InsertPoint, DominatorTree &DT, DenseSet<Instruction *> &Unhoistables, - DenseSet<Instruction *> *HoistStops) { + DenseSet<Instruction *> *HoistStops, + DenseMap<Instruction *, bool> &Visited) { assert(InsertPoint && "Null InsertPoint"); if (auto *I = dyn_cast<Instruction>(V)) { + if (Visited.count(I)) { + return Visited[I]; + } assert(DT.getNode(I->getParent()) && "DT must contain I's parent block"); assert(DT.getNode(InsertPoint->getParent()) && "DT must contain Destination"); if (Unhoistables.count(I)) { // Don't hoist if they are not to be hoisted. + Visited[I] = false; return false; } if (DT.dominates(I, InsertPoint)) { // We are already above the insert point. Stop here. 
if (HoistStops) HoistStops->insert(I); + Visited[I] = true; return true; } // We aren't not above the insert point, check if we can hoist it above the @@ -569,7 +574,8 @@ checkHoistValue(Value *V, Instruction *InsertPoint, DominatorTree &DT, DenseSet<Instruction *> OpsHoistStops; bool AllOpsHoisted = true; for (Value *Op : I->operands()) { - if (!checkHoistValue(Op, InsertPoint, DT, Unhoistables, &OpsHoistStops)) { + if (!checkHoistValue(Op, InsertPoint, DT, Unhoistables, &OpsHoistStops, + Visited)) { AllOpsHoisted = false; break; } @@ -578,9 +584,11 @@ checkHoistValue(Value *V, Instruction *InsertPoint, DominatorTree &DT, CHR_DEBUG(dbgs() << "checkHoistValue " << *I << "\n"); if (HoistStops) HoistStops->insert(OpsHoistStops.begin(), OpsHoistStops.end()); + Visited[I] = true; return true; } } + Visited[I] = false; return false; } // Non-instructions are considered hoistable. @@ -893,8 +901,9 @@ void CHR::checkScopeHoistable(CHRScope *Scope) { ++it; continue; } + DenseMap<Instruction *, bool> Visited; bool IsHoistable = checkHoistValue(SI->getCondition(), InsertPoint, - DT, Unhoistables, nullptr); + DT, Unhoistables, nullptr, Visited); if (!IsHoistable) { CHR_DEBUG(dbgs() << "Dropping select " << *SI << "\n"); ORE.emit([&]() { @@ -913,8 +922,9 @@ void CHR::checkScopeHoistable(CHRScope *Scope) { InsertPoint = getBranchInsertPoint(RI); CHR_DEBUG(dbgs() << "InsertPoint " << *InsertPoint << "\n"); if (RI.HasBranch && InsertPoint != Branch) { + DenseMap<Instruction *, bool> Visited; bool IsHoistable = checkHoistValue(Branch->getCondition(), InsertPoint, - DT, Unhoistables, nullptr); + DT, Unhoistables, nullptr, Visited); if (!IsHoistable) { // If the branch isn't hoistable, drop the selects in the entry // block, preferring the branch, which makes the branch the hoist @@ -945,15 +955,17 @@ void CHR::checkScopeHoistable(CHRScope *Scope) { if (RI.HasBranch) { assert(!DT.dominates(Branch, InsertPoint) && "Branch can't be already above the hoist point"); + DenseMap<Instruction *, bool> Visited; assert(checkHoistValue(Branch->getCondition(), InsertPoint, - DT, Unhoistables, nullptr) && + DT, Unhoistables, nullptr, Visited) && "checkHoistValue for branch"); } for (auto *SI : Selects) { assert(!DT.dominates(SI, InsertPoint) && "SI can't be already above the hoist point"); + DenseMap<Instruction *, bool> Visited; assert(checkHoistValue(SI->getCondition(), InsertPoint, DT, - Unhoistables, nullptr) && + Unhoistables, nullptr, Visited) && "checkHoistValue for selects"); } CHR_DEBUG(dbgs() << "Result\n"); @@ -1054,7 +1066,8 @@ static bool shouldSplit(Instruction *InsertPoint, assert(InsertPoint && "Null InsertPoint"); // If any of Bases isn't hoistable to the hoist point, split. for (Value *V : ConditionValues) { - if (!checkHoistValue(V, InsertPoint, DT, Unhoistables, nullptr)) { + DenseMap<Instruction *, bool> Visited; + if (!checkHoistValue(V, InsertPoint, DT, Unhoistables, nullptr, Visited)) { CHR_DEBUG(dbgs() << "Split. checkHoistValue false " << *V << "\n"); return true; // Not hoistable, split. } @@ -1383,8 +1396,9 @@ void CHR::setCHRRegions(CHRScope *Scope, CHRScope *OutermostScope) { "Must be truthy or falsy"); auto *BI = cast<BranchInst>(R->getEntry()->getTerminator()); // Note checkHoistValue fills in HoistStops. 
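The Visited map threaded through checkHoistValue in these hunks memoizes per-instruction answers so that operand DAGs shared by many branch and select conditions are examined once rather than exponentially often. The caching pattern in isolation looks like the sketch below; it is simplified (the real check also consults the dominator tree and the Unhoistables set) and the function name is invented:

    #include "llvm/ADT/DenseMap.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // Recursively decide a boolean property over an instruction's operand
    // DAG, caching each instruction's answer so shared subtrees are visited
    // only once.
    static bool checkProperty(Value *V,
                              DenseMap<Instruction *, bool> &Visited) {
      auto *I = dyn_cast<Instruction>(V);
      if (!I)
        return true; // constants and arguments are trivially fine
      auto It = Visited.find(I);
      if (It != Visited.end())
        return It->second;
      Visited[I] = false; // break cycles through PHIs while recursing
      bool AllOk = true;
      for (Value *Op : I->operands())
        if (!checkProperty(Op, Visited)) {
          AllOk = false;
          break;
        }
      Visited[I] = AllOk;
      return AllOk;
    }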
+ DenseMap<Instruction *, bool> Visited; bool IsHoistable = checkHoistValue(BI->getCondition(), InsertPoint, DT, - Unhoistables, &HoistStops); + Unhoistables, &HoistStops, Visited); assert(IsHoistable && "Must be hoistable"); (void)(IsHoistable); // Unused in release build IsHoisted = true; @@ -1394,8 +1408,9 @@ void CHR::setCHRRegions(CHRScope *Scope, CHRScope *OutermostScope) { OutermostScope->FalseBiasedSelects.count(SI) > 0) && "Must be true or false biased"); // Note checkHoistValue fills in HoistStops. + DenseMap<Instruction *, bool> Visited; bool IsHoistable = checkHoistValue(SI->getCondition(), InsertPoint, DT, - Unhoistables, &HoistStops); + Unhoistables, &HoistStops, Visited); assert(IsHoistable && "Must be hoistable"); (void)(IsHoistable); // Unused in release build IsHoisted = true; @@ -1417,7 +1432,7 @@ void CHR::sortScopes(SmallVectorImpl<CHRScope *> &Input, SmallVectorImpl<CHRScope *> &Output) { Output.resize(Input.size()); llvm::copy(Input, Output.begin()); - std::stable_sort(Output.begin(), Output.end(), CHRScopeSorter); + llvm::stable_sort(Output, CHRScopeSorter); } // Return true if V is already hoisted or was hoisted (along with its operands) @@ -1425,7 +1440,8 @@ void CHR::sortScopes(SmallVectorImpl<CHRScope *> &Input, static void hoistValue(Value *V, Instruction *HoistPoint, Region *R, HoistStopMapTy &HoistStopMap, DenseSet<Instruction *> &HoistedSet, - DenseSet<PHINode *> &TrivialPHIs) { + DenseSet<PHINode *> &TrivialPHIs, + DominatorTree &DT) { auto IT = HoistStopMap.find(R); assert(IT != HoistStopMap.end() && "Region must be in hoist stop map"); DenseSet<Instruction *> &HoistStops = IT->second; @@ -1445,8 +1461,21 @@ static void hoistValue(Value *V, Instruction *HoistPoint, Region *R, // Already hoisted, return. return; assert(isHoistableInstructionType(I) && "Unhoistable instruction type"); + assert(DT.getNode(I->getParent()) && "DT must contain I's block"); + assert(DT.getNode(HoistPoint->getParent()) && + "DT must contain HoistPoint block"); + if (DT.dominates(I, HoistPoint)) + // We are already above the hoist point. Stop here. This may be necessary + // when multiple scopes would independently hoist the same + // instruction. Since an outer (dominating) scope would hoist it to its + // entry before an inner (dominated) scope would to its entry, the inner + // scope may see the instruction already hoisted, in which case it + // potentially wrong for the inner scope to hoist it and could cause bad + // IR (non-dominating def), but safe to skip hoisting it instead because + // it's already in a block that dominates the inner scope. + return; for (Value *Op : I->operands()) { - hoistValue(Op, HoistPoint, R, HoistStopMap, HoistedSet, TrivialPHIs); + hoistValue(Op, HoistPoint, R, HoistStopMap, HoistedSet, TrivialPHIs, DT); } I->moveBefore(HoistPoint); HoistedSet.insert(I); @@ -1457,7 +1486,8 @@ static void hoistValue(Value *V, Instruction *HoistPoint, Region *R, // Hoist the dependent condition values of the branches and the selects in the // scope to the insert point. 
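sortEdgesByWeight in CFGMST.h and sortScopes above both move from std::stable_sort over explicit begin/end iterators to the range-based llvm::stable_sort from STLExtras.h; the behavior is unchanged, the range form is simply terser. A tiny sketch with an invented helper:

    #include "llvm/ADT/STLExtras.h"
    #include "llvm/ADT/SmallVector.h"

    // Sort descending by weight while preserving the relative order of ties,
    // equivalent to std::stable_sort(W.begin(), W.end(), ...).
    static void sortByWeight(llvm::SmallVectorImpl<int> &W) {
      llvm::stable_sort(W, [](int A, int B) { return A > B; });
    }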
static void hoistScopeConditions(CHRScope *Scope, Instruction *HoistPoint, - DenseSet<PHINode *> &TrivialPHIs) { + DenseSet<PHINode *> &TrivialPHIs, + DominatorTree &DT) { DenseSet<Instruction *> HoistedSet; for (const RegInfo &RI : Scope->CHRRegions) { Region *R = RI.R; @@ -1466,7 +1496,7 @@ static void hoistScopeConditions(CHRScope *Scope, Instruction *HoistPoint, if (RI.HasBranch && (IsTrueBiased || IsFalseBiased)) { auto *BI = cast<BranchInst>(R->getEntry()->getTerminator()); hoistValue(BI->getCondition(), HoistPoint, R, Scope->HoistStopMap, - HoistedSet, TrivialPHIs); + HoistedSet, TrivialPHIs, DT); } for (SelectInst *SI : RI.Selects) { bool IsTrueBiased = Scope->TrueBiasedSelects.count(SI); @@ -1474,7 +1504,7 @@ static void hoistScopeConditions(CHRScope *Scope, Instruction *HoistPoint, if (!(IsTrueBiased || IsFalseBiased)) continue; hoistValue(SI->getCondition(), HoistPoint, R, Scope->HoistStopMap, - HoistedSet, TrivialPHIs); + HoistedSet, TrivialPHIs, DT); } } } @@ -1708,7 +1738,7 @@ void CHR::transformScopes(CHRScope *Scope, DenseSet<PHINode *> &TrivialPHIs) { #endif // Hoist the conditional values of the branches/selects. - hoistScopeConditions(Scope, PreEntryBlock->getTerminator(), TrivialPHIs); + hoistScopeConditions(Scope, PreEntryBlock->getTerminator(), TrivialPHIs, DT); #ifndef NDEBUG assertBranchOrSelectConditionHoisted(Scope, PreEntryBlock); diff --git a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index 4c3c6c9added..2279c1bcb6a8 100644 --- a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -1,9 +1,8 @@ //===- DataFlowSanitizer.cpp - dynamic data flow analysis -----------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -333,6 +332,8 @@ class DataFlowSanitizer : public ModulePass { Constant *RetvalTLS; void *(*GetArgTLSPtr)(); void *(*GetRetvalTLSPtr)(); + FunctionType *GetArgTLSTy; + FunctionType *GetRetvalTLSTy; Constant *GetArgTLS; Constant *GetRetvalTLS; Constant *ExternalShadowMask; @@ -342,13 +343,13 @@ class DataFlowSanitizer : public ModulePass { FunctionType *DFSanSetLabelFnTy; FunctionType *DFSanNonzeroLabelFnTy; FunctionType *DFSanVarargWrapperFnTy; - Constant *DFSanUnionFn; - Constant *DFSanCheckedUnionFn; - Constant *DFSanUnionLoadFn; - Constant *DFSanUnimplementedFn; - Constant *DFSanSetLabelFn; - Constant *DFSanNonzeroLabelFn; - Constant *DFSanVarargWrapperFn; + FunctionCallee DFSanUnionFn; + FunctionCallee DFSanCheckedUnionFn; + FunctionCallee DFSanUnionLoadFn; + FunctionCallee DFSanUnimplementedFn; + FunctionCallee DFSanSetLabelFn; + FunctionCallee DFSanNonzeroLabelFn; + FunctionCallee DFSanVarargWrapperFn; MDNode *ColdCallWeights; DFSanABIList ABIList; DenseMap<Value *, Function *> UnwrappedFnMap; @@ -436,6 +437,7 @@ public: } void visitOperandShadowInst(Instruction &I); + void visitUnaryOperator(UnaryOperator &UO); void visitBinaryOperator(BinaryOperator &BO); void visitCastInst(CastInst &CI); void visitCmpInst(CmpInst &CI); @@ -581,17 +583,17 @@ bool DataFlowSanitizer::doInitialization(Module &M) { if (GetArgTLSPtr) { Type *ArgTLSTy = ArrayType::get(ShadowTy, 64); ArgTLS = nullptr; + GetArgTLSTy = FunctionType::get(PointerType::getUnqual(ArgTLSTy), false); GetArgTLS = ConstantExpr::getIntToPtr( ConstantInt::get(IntptrTy, uintptr_t(GetArgTLSPtr)), - PointerType::getUnqual( - FunctionType::get(PointerType::getUnqual(ArgTLSTy), false))); + PointerType::getUnqual(GetArgTLSTy)); } if (GetRetvalTLSPtr) { RetvalTLS = nullptr; + GetRetvalTLSTy = FunctionType::get(PointerType::getUnqual(ShadowTy), false); GetRetvalTLS = ConstantExpr::getIntToPtr( ConstantInt::get(IntptrTy, uintptr_t(GetRetvalTLSPtr)), - PointerType::getUnqual( - FunctionType::get(PointerType::getUnqual(ShadowTy), false))); + PointerType::getUnqual(GetRetvalTLSTy)); } ColdCallWeights = MDBuilder(*Ctx).createBranchWeights(1, 1000); @@ -678,8 +680,8 @@ DataFlowSanitizer::buildWrapperFunction(Function *F, StringRef NewFName, Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT, StringRef FName) { FunctionType *FTT = getTrampolineFunctionType(FT); - Constant *C = Mod->getOrInsertFunction(FName, FTT); - Function *F = dyn_cast<Function>(C); + FunctionCallee C = Mod->getOrInsertFunction(FName, FTT); + Function *F = dyn_cast<Function>(C.getCallee()); if (F && F->isDeclaration()) { F->setLinkage(GlobalValue::LinkOnceODRLinkage); BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", F); @@ -687,7 +689,7 @@ Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT, Function::arg_iterator AI = F->arg_begin(); ++AI; for (unsigned N = FT->getNumParams(); N != 0; ++AI, --N) Args.push_back(&*AI); - CallInst *CI = CallInst::Create(&*F->arg_begin(), Args, "", BB); + CallInst *CI = CallInst::Create(FT, &*F->arg_begin(), Args, "", BB); ReturnInst *RI; if (FT->getReturnType()->isVoidTy()) RI = ReturnInst::Create(*Ctx, BB); @@ -704,7 +706,7 @@ Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT, &*std::prev(F->arg_end()), RI); } - return C; + return cast<Constant>(C.getCallee()); } bool DataFlowSanitizer::runOnModule(Module &M) { @@ -726,35 
+728,51 @@ bool DataFlowSanitizer::runOnModule(Module &M) { ExternalShadowMask = Mod->getOrInsertGlobal(kDFSanExternShadowPtrMask, IntptrTy); - DFSanUnionFn = Mod->getOrInsertFunction("__dfsan_union", DFSanUnionFnTy); - if (Function *F = dyn_cast<Function>(DFSanUnionFn)) { - F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind); - F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); - F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); - F->addParamAttr(0, Attribute::ZExt); - F->addParamAttr(1, Attribute::ZExt); + { + AttributeList AL; + AL = AL.addAttribute(M.getContext(), AttributeList::FunctionIndex, + Attribute::NoUnwind); + AL = AL.addAttribute(M.getContext(), AttributeList::FunctionIndex, + Attribute::ReadNone); + AL = AL.addAttribute(M.getContext(), AttributeList::ReturnIndex, + Attribute::ZExt); + AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); + AL = AL.addParamAttribute(M.getContext(), 1, Attribute::ZExt); + DFSanUnionFn = + Mod->getOrInsertFunction("__dfsan_union", DFSanUnionFnTy, AL); } - DFSanCheckedUnionFn = Mod->getOrInsertFunction("dfsan_union", DFSanUnionFnTy); - if (Function *F = dyn_cast<Function>(DFSanCheckedUnionFn)) { - F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind); - F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); - F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); - F->addParamAttr(0, Attribute::ZExt); - F->addParamAttr(1, Attribute::ZExt); + + { + AttributeList AL; + AL = AL.addAttribute(M.getContext(), AttributeList::FunctionIndex, + Attribute::NoUnwind); + AL = AL.addAttribute(M.getContext(), AttributeList::FunctionIndex, + Attribute::ReadNone); + AL = AL.addAttribute(M.getContext(), AttributeList::ReturnIndex, + Attribute::ZExt); + AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); + AL = AL.addParamAttribute(M.getContext(), 1, Attribute::ZExt); + DFSanCheckedUnionFn = + Mod->getOrInsertFunction("dfsan_union", DFSanUnionFnTy, AL); } - DFSanUnionLoadFn = - Mod->getOrInsertFunction("__dfsan_union_load", DFSanUnionLoadFnTy); - if (Function *F = dyn_cast<Function>(DFSanUnionLoadFn)) { - F->addAttribute(AttributeList::FunctionIndex, Attribute::NoUnwind); - F->addAttribute(AttributeList::FunctionIndex, Attribute::ReadOnly); - F->addAttribute(AttributeList::ReturnIndex, Attribute::ZExt); + { + AttributeList AL; + AL = AL.addAttribute(M.getContext(), AttributeList::FunctionIndex, + Attribute::NoUnwind); + AL = AL.addAttribute(M.getContext(), AttributeList::FunctionIndex, + Attribute::ReadOnly); + AL = AL.addAttribute(M.getContext(), AttributeList::ReturnIndex, + Attribute::ZExt); + DFSanUnionLoadFn = + Mod->getOrInsertFunction("__dfsan_union_load", DFSanUnionLoadFnTy, AL); } DFSanUnimplementedFn = Mod->getOrInsertFunction("__dfsan_unimplemented", DFSanUnimplementedFnTy); - DFSanSetLabelFn = - Mod->getOrInsertFunction("__dfsan_set_label", DFSanSetLabelFnTy); - if (Function *F = dyn_cast<Function>(DFSanSetLabelFn)) { - F->addParamAttr(0, Attribute::ZExt); + { + AttributeList AL; + AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); + DFSanSetLabelFn = + Mod->getOrInsertFunction("__dfsan_set_label", DFSanSetLabelFnTy, AL); } DFSanNonzeroLabelFn = Mod->getOrInsertFunction("__dfsan_nonzero_label", DFSanNonzeroLabelFnTy); @@ -765,13 +783,13 @@ bool DataFlowSanitizer::runOnModule(Module &M) { SmallPtrSet<Function *, 2> FnsWithNativeABI; for (Function &i : M) { if (!i.isIntrinsic() && - &i != DFSanUnionFn && - &i != DFSanCheckedUnionFn && - &i != 
DFSanUnionLoadFn && - &i != DFSanUnimplementedFn && - &i != DFSanSetLabelFn && - &i != DFSanNonzeroLabelFn && - &i != DFSanVarargWrapperFn) + &i != DFSanUnionFn.getCallee()->stripPointerCasts() && + &i != DFSanCheckedUnionFn.getCallee()->stripPointerCasts() && + &i != DFSanUnionLoadFn.getCallee()->stripPointerCasts() && + &i != DFSanUnimplementedFn.getCallee()->stripPointerCasts() && + &i != DFSanSetLabelFn.getCallee()->stripPointerCasts() && + &i != DFSanNonzeroLabelFn.getCallee()->stripPointerCasts() && + &i != DFSanVarargWrapperFn.getCallee()->stripPointerCasts()) FnsToInstrument.push_back(&i); } @@ -982,7 +1000,7 @@ Value *DFSanFunction::getArgTLSPtr() { return ArgTLSPtr = DFS.ArgTLS; IRBuilder<> IRB(&F->getEntryBlock().front()); - return ArgTLSPtr = IRB.CreateCall(DFS.GetArgTLS, {}); + return ArgTLSPtr = IRB.CreateCall(DFS.GetArgTLSTy, DFS.GetArgTLS, {}); } Value *DFSanFunction::getRetvalTLS() { @@ -992,12 +1010,14 @@ Value *DFSanFunction::getRetvalTLS() { return RetvalTLSPtr = DFS.RetvalTLS; IRBuilder<> IRB(&F->getEntryBlock().front()); - return RetvalTLSPtr = IRB.CreateCall(DFS.GetRetvalTLS, {}); + return RetvalTLSPtr = + IRB.CreateCall(DFS.GetRetvalTLSTy, DFS.GetRetvalTLS, {}); } Value *DFSanFunction::getArgTLS(unsigned Idx, Instruction *Pos) { IRBuilder<> IRB(Pos); - return IRB.CreateConstGEP2_64(getArgTLSPtr(), 0, Idx); + return IRB.CreateConstGEP2_64(ArrayType::get(DFS.ShadowTy, 64), + getArgTLSPtr(), 0, Idx); } Value *DFSanFunction::getShadow(Value *V) { @@ -1015,7 +1035,8 @@ Value *DFSanFunction::getShadow(Value *V) { DFS.ArgTLS ? &*F->getEntryBlock().begin() : cast<Instruction>(ArgTLSPtr)->getNextNode(); IRBuilder<> IRB(ArgTLSPos); - Shadow = IRB.CreateLoad(getArgTLS(A->getArgNo(), ArgTLSPos)); + Shadow = + IRB.CreateLoad(DFS.ShadowTy, getArgTLS(A->getArgNo(), ArgTLSPos)); break; } case DataFlowSanitizer::IA_Args: { @@ -1165,15 +1186,15 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, const auto i = AllocaShadowMap.find(AI); if (i != AllocaShadowMap.end()) { IRBuilder<> IRB(Pos); - return IRB.CreateLoad(i->second); + return IRB.CreateLoad(DFS.ShadowTy, i->second); } } uint64_t ShadowAlign = Align * DFS.ShadowWidth / 8; - SmallVector<Value *, 2> Objs; + SmallVector<const Value *, 2> Objs; GetUnderlyingObjects(Addr, Objs, Pos->getModule()->getDataLayout()); bool AllConstants = true; - for (Value *Obj : Objs) { + for (const Value *Obj : Objs) { if (isa<Function>(Obj) || isa<BlockAddress>(Obj)) continue; if (isa<GlobalVariable>(Obj) && cast<GlobalVariable>(Obj)->isConstant()) @@ -1190,7 +1211,7 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, case 0: return DFS.ZeroShadow; case 1: { - LoadInst *LI = new LoadInst(ShadowAddr, "", Pos); + LoadInst *LI = new LoadInst(DFS.ShadowTy, ShadowAddr, "", Pos); LI->setAlignment(ShadowAlign); return LI; } @@ -1198,8 +1219,9 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, IRBuilder<> IRB(Pos); Value *ShadowAddr1 = IRB.CreateGEP(DFS.ShadowTy, ShadowAddr, ConstantInt::get(DFS.IntptrTy, 1)); - return combineShadows(IRB.CreateAlignedLoad(ShadowAddr, ShadowAlign), - IRB.CreateAlignedLoad(ShadowAddr1, ShadowAlign), Pos); + return combineShadows( + IRB.CreateAlignedLoad(DFS.ShadowTy, ShadowAddr, ShadowAlign), + IRB.CreateAlignedLoad(DFS.ShadowTy, ShadowAddr1, ShadowAlign), Pos); } } if (!AvoidNewBlocks && Size % (64 / DFS.ShadowWidth) == 0) { @@ -1218,7 +1240,8 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, IRBuilder<> 
IRB(Pos); Value *WideAddr = IRB.CreateBitCast(ShadowAddr, Type::getInt64PtrTy(*DFS.Ctx)); - Value *WideShadow = IRB.CreateAlignedLoad(WideAddr, ShadowAlign); + Value *WideShadow = + IRB.CreateAlignedLoad(IRB.getInt64Ty(), WideAddr, ShadowAlign); Value *TruncShadow = IRB.CreateTrunc(WideShadow, DFS.ShadowTy); Value *ShlShadow = IRB.CreateShl(WideShadow, DFS.ShadowWidth); Value *ShrShadow = IRB.CreateLShr(WideShadow, 64 - DFS.ShadowWidth); @@ -1251,7 +1274,8 @@ Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, IRBuilder<> NextIRB(NextBB); WideAddr = NextIRB.CreateGEP(Type::getInt64Ty(*DFS.Ctx), WideAddr, ConstantInt::get(DFS.IntptrTy, 1)); - Value *NextWideShadow = NextIRB.CreateAlignedLoad(WideAddr, ShadowAlign); + Value *NextWideShadow = NextIRB.CreateAlignedLoad(NextIRB.getInt64Ty(), + WideAddr, ShadowAlign); ShadowsEq = NextIRB.CreateICmpEQ(WideShadow, NextWideShadow); LastBr->setSuccessor(0, NextBB); LastBr = NextIRB.CreateCondBr(ShadowsEq, FallbackBB, FallbackBB); @@ -1375,6 +1399,10 @@ void DFSanVisitor::visitStoreInst(StoreInst &SI) { DFSF.storeShadow(SI.getPointerOperand(), Size, Align, Shadow, &SI); } +void DFSanVisitor::visitUnaryOperator(UnaryOperator &UO) { + visitOperandShadowInst(UO); +} + void DFSanVisitor::visitBinaryOperator(BinaryOperator &BO) { visitOperandShadowInst(BO); } @@ -1470,7 +1498,7 @@ void DFSanVisitor::visitMemTransferInst(MemTransferInst &I) { DestShadow = IRB.CreateBitCast(DestShadow, Int8Ptr); SrcShadow = IRB.CreateBitCast(SrcShadow, Int8Ptr); auto *MTI = cast<MemTransferInst>( - IRB.CreateCall(I.getCalledValue(), + IRB.CreateCall(I.getFunctionType(), I.getCalledValue(), {DestShadow, SrcShadow, LenShadow, I.getVolatileCst()})); if (ClPreserveAlignment) { MTI->setDestAlignment(I.getDestAlignment() * (DFSF.DFS.ShadowWidth / 8)); @@ -1513,7 +1541,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) { // Calls to this function are synthesized in wrappers, and we shouldn't // instrument them. - if (F == DFSF.DFS.DFSanVarargWrapperFn) + if (F == DFSF.DFS.DFSanVarargWrapperFn.getCallee()->stripPointerCasts()) return; IRBuilder<> IRB(CS.getInstruction()); @@ -1546,9 +1574,9 @@ void DFSanVisitor::visitCallSite(CallSite CS) { TransformedFunction CustomFn = DFSF.DFS.getCustomFunctionType(FT); std::string CustomFName = "__dfsw_"; CustomFName += F->getName(); - Constant *CustomF = DFSF.DFS.Mod->getOrInsertFunction( + FunctionCallee CustomF = DFSF.DFS.Mod->getOrInsertFunction( CustomFName, CustomFn.TransformedType); - if (Function *CustomFn = dyn_cast<Function>(CustomF)) { + if (Function *CustomFn = dyn_cast<Function>(CustomF.getCallee())) { CustomFn->copyAttributesFrom(F); // Custom functions returning non-void will write to the return label. 
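Two idioms recur throughout the DataFlowSanitizer hunks above: Module::getOrInsertFunction now returns a FunctionCallee that carries the FunctionType alongside the callee, and attributes are supplied up front as an AttributeList rather than being patched onto a dyn_cast'ed Function afterwards. A compact sketch of the combined pattern, using a made-up "__example_hook" runtime function:

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
using namespace llvm;

static CallInst *emitRuntimeHook(Module &M, IRBuilder<> &IRB, Value *Byte) {
  LLVMContext &C = M.getContext();
  FunctionType *FnTy = FunctionType::get(Type::getVoidTy(C),
                                         {Type::getInt8Ty(C)},
                                         /*isVarArg=*/false);
  AttributeList AL;
  AL = AL.addAttribute(C, AttributeList::FunctionIndex, Attribute::NoUnwind);
  AL = AL.addParamAttribute(C, 0, Attribute::ZExt);
  // The attributes travel with the declaration; no post-hoc Function fixup.
  FunctionCallee Hook = M.getOrInsertFunction("__example_hook", FnTy, AL);
  // FunctionCallee keeps the call fully typed even when the callee constant
  // is a bitcast rather than a plain Function.
  return IRB.CreateCall(Hook, {Byte});
}
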
@@ -1628,7 +1656,8 @@ void DFSanVisitor::visitCallSite(CallSite CS) { } if (!FT->getReturnType()->isVoidTy()) { - LoadInst *LabelLoad = IRB.CreateLoad(DFSF.LabelReturnAlloca); + LoadInst *LabelLoad = + IRB.CreateLoad(DFSF.DFS.ShadowTy, DFSF.LabelReturnAlloca); DFSF.setShadow(CustomCI, LabelLoad); } @@ -1666,7 +1695,7 @@ void DFSanVisitor::visitCallSite(CallSite CS) { if (DFSF.DFS.getInstrumentedABI() == DataFlowSanitizer::IA_TLS) { IRBuilder<> NextIRB(Next); - LoadInst *LI = NextIRB.CreateLoad(DFSF.getRetvalTLS()); + LoadInst *LI = NextIRB.CreateLoad(DFSF.DFS.ShadowTy, DFSF.getRetvalTLS()); DFSF.SkipInsts.insert(LI); DFSF.setShadow(CS.getInstruction(), LI); DFSF.NonZeroChecks.push_back(LI); @@ -1706,10 +1735,10 @@ void DFSanVisitor::visitCallSite(CallSite CS) { CallSite NewCS; if (InvokeInst *II = dyn_cast<InvokeInst>(CS.getInstruction())) { - NewCS = IRB.CreateInvoke(Func, II->getNormalDest(), II->getUnwindDest(), - Args); + NewCS = IRB.CreateInvoke(NewFT, Func, II->getNormalDest(), + II->getUnwindDest(), Args); } else { - NewCS = IRB.CreateCall(Func, Args); + NewCS = IRB.CreateCall(NewFT, Func, Args); } NewCS.setCallingConv(CS.getCallingConv()); NewCS.setAttributes(CS.getAttributes().removeAttributes( diff --git a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp deleted file mode 100644 index db438e78ded9..000000000000 --- a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp +++ /dev/null @@ -1,900 +0,0 @@ -//===-- EfficiencySanitizer.cpp - performance tuner -----------------------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// This file is a part of EfficiencySanitizer, a family of performance tuners -// that detects multiple performance issues via separate sub-tools. -// -// The instrumentation phase is straightforward: -// - Take action on every memory access: either inlined instrumentation, -// or Inserted calls to our run-time library. -// - Optimizations may apply to avoid instrumenting some of the accesses. -// - Turn mem{set,cpy,move} instrinsics into library calls. -// The rest is handled by the run-time library. -//===----------------------------------------------------------------------===// - -#include "llvm/ADT/SmallString.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Instrumentation.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/ModuleUtils.h" - -using namespace llvm; - -#define DEBUG_TYPE "esan" - -// The tool type must be just one of these ClTool* options, as the tools -// cannot be combined due to shadow memory constraints. 
-static cl::opt<bool> - ClToolCacheFrag("esan-cache-frag", cl::init(false), - cl::desc("Detect data cache fragmentation"), cl::Hidden); -static cl::opt<bool> - ClToolWorkingSet("esan-working-set", cl::init(false), - cl::desc("Measure the working set size"), cl::Hidden); -// Each new tool will get its own opt flag here. -// These are converted to EfficiencySanitizerOptions for use -// in the code. - -static cl::opt<bool> ClInstrumentLoadsAndStores( - "esan-instrument-loads-and-stores", cl::init(true), - cl::desc("Instrument loads and stores"), cl::Hidden); -static cl::opt<bool> ClInstrumentMemIntrinsics( - "esan-instrument-memintrinsics", cl::init(true), - cl::desc("Instrument memintrinsics (memset/memcpy/memmove)"), cl::Hidden); -static cl::opt<bool> ClInstrumentFastpath( - "esan-instrument-fastpath", cl::init(true), - cl::desc("Instrument fastpath"), cl::Hidden); -static cl::opt<bool> ClAuxFieldInfo( - "esan-aux-field-info", cl::init(true), - cl::desc("Generate binary with auxiliary struct field information"), - cl::Hidden); - -// Experiments show that the performance difference can be 2x or more, -// and accuracy loss is typically negligible, so we turn this on by default. -static cl::opt<bool> ClAssumeIntraCacheLine( - "esan-assume-intra-cache-line", cl::init(true), - cl::desc("Assume each memory access touches just one cache line, for " - "better performance but with a potential loss of accuracy."), - cl::Hidden); - -STATISTIC(NumInstrumentedLoads, "Number of instrumented loads"); -STATISTIC(NumInstrumentedStores, "Number of instrumented stores"); -STATISTIC(NumFastpaths, "Number of instrumented fastpaths"); -STATISTIC(NumAccessesWithIrregularSize, - "Number of accesses with a size outside our targeted callout sizes"); -STATISTIC(NumIgnoredStructs, "Number of ignored structs"); -STATISTIC(NumIgnoredGEPs, "Number of ignored GEP instructions"); -STATISTIC(NumInstrumentedGEPs, "Number of instrumented GEP instructions"); -STATISTIC(NumAssumedIntraCacheLine, - "Number of accesses assumed to be intra-cache-line"); - -static const uint64_t EsanCtorAndDtorPriority = 0; -static const char *const EsanModuleCtorName = "esan.module_ctor"; -static const char *const EsanModuleDtorName = "esan.module_dtor"; -static const char *const EsanInitName = "__esan_init"; -static const char *const EsanExitName = "__esan_exit"; - -// We need to specify the tool to the runtime earlier than -// the ctor is called in some cases, so we set a global variable. -static const char *const EsanWhichToolName = "__esan_which_tool"; - -// We must keep these Shadow* constants consistent with the esan runtime. -// FIXME: Try to place these shadow constants, the names of the __esan_* -// interface functions, and the ToolType enum into a header shared between -// llvm and compiler-rt. -struct ShadowMemoryParams { - uint64_t ShadowMask; - uint64_t ShadowOffs[3]; -}; - -static const ShadowMemoryParams ShadowParams47 = { - 0x00000fffffffffffull, - { - 0x0000130000000000ull, 0x0000220000000000ull, 0x0000440000000000ull, - }}; - -static const ShadowMemoryParams ShadowParams40 = { - 0x0fffffffffull, - { - 0x1300000000ull, 0x2200000000ull, 0x4400000000ull, - }}; - -// This array is indexed by the ToolType enum. -static const int ShadowScale[] = { - 0, // ESAN_None. - 2, // ESAN_CacheFrag: 4B:1B, so 4 to 1 == >>2. - 6, // ESAN_WorkingSet: 64B:1B, so 64 to 1 == >>6. -}; - -// MaxStructCounterNameSize is a soft size limit to avoid insanely long -// names for those extremely large structs. 
-static const unsigned MaxStructCounterNameSize = 512; - -namespace { - -static EfficiencySanitizerOptions -OverrideOptionsFromCL(EfficiencySanitizerOptions Options) { - if (ClToolCacheFrag) - Options.ToolType = EfficiencySanitizerOptions::ESAN_CacheFrag; - else if (ClToolWorkingSet) - Options.ToolType = EfficiencySanitizerOptions::ESAN_WorkingSet; - - // Direct opt invocation with no params will have the default ESAN_None. - // We run the default tool in that case. - if (Options.ToolType == EfficiencySanitizerOptions::ESAN_None) - Options.ToolType = EfficiencySanitizerOptions::ESAN_CacheFrag; - - return Options; -} - -/// EfficiencySanitizer: instrument each module to find performance issues. -class EfficiencySanitizer : public ModulePass { -public: - EfficiencySanitizer( - const EfficiencySanitizerOptions &Opts = EfficiencySanitizerOptions()) - : ModulePass(ID), Options(OverrideOptionsFromCL(Opts)) {} - StringRef getPassName() const override; - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnModule(Module &M) override; - static char ID; - -private: - bool initOnModule(Module &M); - void initializeCallbacks(Module &M); - bool shouldIgnoreStructType(StructType *StructTy); - void createStructCounterName( - StructType *StructTy, SmallString<MaxStructCounterNameSize> &NameStr); - void createCacheFragAuxGV( - Module &M, const DataLayout &DL, StructType *StructTy, - GlobalVariable *&TypeNames, GlobalVariable *&Offsets, GlobalVariable *&Size); - GlobalVariable *createCacheFragInfoGV(Module &M, const DataLayout &DL, - Constant *UnitName); - Constant *createEsanInitToolInfoArg(Module &M, const DataLayout &DL); - void createDestructor(Module &M, Constant *ToolInfoArg); - bool runOnFunction(Function &F, Module &M); - bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL); - bool instrumentMemIntrinsic(MemIntrinsic *MI); - bool instrumentGetElementPtr(Instruction *I, Module &M); - bool insertCounterUpdate(Instruction *I, StructType *StructTy, - unsigned CounterIdx); - unsigned getFieldCounterIdx(StructType *StructTy) { - return 0; - } - unsigned getArrayCounterIdx(StructType *StructTy) { - return StructTy->getNumElements(); - } - unsigned getStructCounterSize(StructType *StructTy) { - // The struct counter array includes: - // - one counter for each struct field, - // - one counter for the struct access within an array. - return (StructTy->getNumElements()/*field*/ + 1/*array*/); - } - bool shouldIgnoreMemoryAccess(Instruction *I); - int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL); - Value *appToShadow(Value *Shadow, IRBuilder<> &IRB); - bool instrumentFastpath(Instruction *I, const DataLayout &DL, bool IsStore, - Value *Addr, unsigned Alignment); - // Each tool has its own fastpath routine: - bool instrumentFastpathCacheFrag(Instruction *I, const DataLayout &DL, - Value *Addr, unsigned Alignment); - bool instrumentFastpathWorkingSet(Instruction *I, const DataLayout &DL, - Value *Addr, unsigned Alignment); - - EfficiencySanitizerOptions Options; - LLVMContext *Ctx; - Type *IntptrTy; - // Our slowpath involves callouts to the runtime library. - // Access sizes are powers of two: 1, 2, 4, 8, 16. 
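As the comment above says, the per-size callout arrays are indexed by power-of-two access size; getMemoryAccessFuncIndex (deleted further down) derives that index with countTrailingZeros. The mapping in isolation, mirroring the deleted logic (sketch only):

#include "llvm/Support/MathExtras.h"
#include <cstdint>

// 1 -> 0, 2 -> 1, 4 -> 2, 8 -> 3, 16 -> 4; other sizes get no per-size slot.
static int accessSizeToSlot(uint32_t SizeInBytes) {
  if (SizeInBytes != 1 && SizeInBytes != 2 && SizeInBytes != 4 &&
      SizeInBytes != 8 && SizeInBytes != 16)
    return -1;                               // Irregular size: use the *N callout.
  return llvm::countTrailingZeros(SizeInBytes);
}
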
- static const size_t NumberOfAccessSizes = 5; - Function *EsanAlignedLoad[NumberOfAccessSizes]; - Function *EsanAlignedStore[NumberOfAccessSizes]; - Function *EsanUnalignedLoad[NumberOfAccessSizes]; - Function *EsanUnalignedStore[NumberOfAccessSizes]; - // For irregular sizes of any alignment: - Function *EsanUnalignedLoadN, *EsanUnalignedStoreN; - Function *MemmoveFn, *MemcpyFn, *MemsetFn; - Function *EsanCtorFunction; - Function *EsanDtorFunction; - // Remember the counter variable for each struct type to avoid - // recomputing the variable name later during instrumentation. - std::map<Type *, GlobalVariable *> StructTyMap; - ShadowMemoryParams ShadowParams; -}; -} // namespace - -char EfficiencySanitizer::ID = 0; -INITIALIZE_PASS_BEGIN( - EfficiencySanitizer, "esan", - "EfficiencySanitizer: finds performance issues.", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END( - EfficiencySanitizer, "esan", - "EfficiencySanitizer: finds performance issues.", false, false) - -StringRef EfficiencySanitizer::getPassName() const { - return "EfficiencySanitizer"; -} - -void EfficiencySanitizer::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired<TargetLibraryInfoWrapperPass>(); -} - -ModulePass * -llvm::createEfficiencySanitizerPass(const EfficiencySanitizerOptions &Options) { - return new EfficiencySanitizer(Options); -} - -void EfficiencySanitizer::initializeCallbacks(Module &M) { - IRBuilder<> IRB(M.getContext()); - // Initialize the callbacks. - for (size_t Idx = 0; Idx < NumberOfAccessSizes; ++Idx) { - const unsigned ByteSize = 1U << Idx; - std::string ByteSizeStr = utostr(ByteSize); - // We'll inline the most common (i.e., aligned and frequent sizes) - // load + store instrumentation: these callouts are for the slowpath. 
- SmallString<32> AlignedLoadName("__esan_aligned_load" + ByteSizeStr); - EsanAlignedLoad[Idx] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - AlignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy())); - SmallString<32> AlignedStoreName("__esan_aligned_store" + ByteSizeStr); - EsanAlignedStore[Idx] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - AlignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy())); - SmallString<32> UnalignedLoadName("__esan_unaligned_load" + ByteSizeStr); - EsanUnalignedLoad[Idx] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - UnalignedLoadName, IRB.getVoidTy(), IRB.getInt8PtrTy())); - SmallString<32> UnalignedStoreName("__esan_unaligned_store" + ByteSizeStr); - EsanUnalignedStore[Idx] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - UnalignedStoreName, IRB.getVoidTy(), IRB.getInt8PtrTy())); - } - EsanUnalignedLoadN = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("__esan_unaligned_loadN", IRB.getVoidTy(), - IRB.getInt8PtrTy(), IntptrTy)); - EsanUnalignedStoreN = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("__esan_unaligned_storeN", IRB.getVoidTy(), - IRB.getInt8PtrTy(), IntptrTy)); - MemmoveFn = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy)); - MemcpyFn = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy)); - MemsetFn = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt32Ty(), IntptrTy)); -} - -bool EfficiencySanitizer::shouldIgnoreStructType(StructType *StructTy) { - if (StructTy == nullptr || StructTy->isOpaque() /* no struct body */) - return true; - return false; -} - -void EfficiencySanitizer::createStructCounterName( - StructType *StructTy, SmallString<MaxStructCounterNameSize> &NameStr) { - // Append NumFields and field type ids to avoid struct conflicts - // with the same name but different fields. - if (StructTy->hasName()) - NameStr += StructTy->getName(); - else - NameStr += "struct.anon"; - // We allow the actual size of the StructCounterName to be larger than - // MaxStructCounterNameSize and append $NumFields and at least one - // field type id. - // Append $NumFields. - NameStr += "$"; - Twine(StructTy->getNumElements()).toVector(NameStr); - // Append struct field type ids in the reverse order. - for (int i = StructTy->getNumElements() - 1; i >= 0; --i) { - NameStr += "$"; - Twine(StructTy->getElementType(i)->getTypeID()).toVector(NameStr); - if (NameStr.size() >= MaxStructCounterNameSize) - break; - } - if (StructTy->isLiteral()) { - // End with $ for literal struct. - NameStr += "$"; - } -} - -// Create global variables with auxiliary information (e.g., struct field size, -// offset, and type name) for better user report. -void EfficiencySanitizer::createCacheFragAuxGV( - Module &M, const DataLayout &DL, StructType *StructTy, - GlobalVariable *&TypeName, GlobalVariable *&Offset, - GlobalVariable *&Size) { - auto *Int8PtrTy = Type::getInt8PtrTy(*Ctx); - auto *Int32Ty = Type::getInt32Ty(*Ctx); - // FieldTypeName. - auto *TypeNameArrayTy = ArrayType::get(Int8PtrTy, StructTy->getNumElements()); - TypeName = new GlobalVariable(M, TypeNameArrayTy, true, - GlobalVariable::InternalLinkage, nullptr); - SmallVector<Constant *, 16> TypeNameVec; - // FieldOffset. 
- auto *OffsetArrayTy = ArrayType::get(Int32Ty, StructTy->getNumElements()); - Offset = new GlobalVariable(M, OffsetArrayTy, true, - GlobalVariable::InternalLinkage, nullptr); - SmallVector<Constant *, 16> OffsetVec; - // FieldSize - auto *SizeArrayTy = ArrayType::get(Int32Ty, StructTy->getNumElements()); - Size = new GlobalVariable(M, SizeArrayTy, true, - GlobalVariable::InternalLinkage, nullptr); - SmallVector<Constant *, 16> SizeVec; - for (unsigned i = 0; i < StructTy->getNumElements(); ++i) { - Type *Ty = StructTy->getElementType(i); - std::string Str; - raw_string_ostream StrOS(Str); - Ty->print(StrOS); - TypeNameVec.push_back( - ConstantExpr::getPointerCast( - createPrivateGlobalForString(M, StrOS.str(), true), - Int8PtrTy)); - OffsetVec.push_back( - ConstantInt::get(Int32Ty, - DL.getStructLayout(StructTy)->getElementOffset(i))); - SizeVec.push_back(ConstantInt::get(Int32Ty, - DL.getTypeAllocSize(Ty))); - } - TypeName->setInitializer(ConstantArray::get(TypeNameArrayTy, TypeNameVec)); - Offset->setInitializer(ConstantArray::get(OffsetArrayTy, OffsetVec)); - Size->setInitializer(ConstantArray::get(SizeArrayTy, SizeVec)); -} - -// Create the global variable for the cache-fragmentation tool. -GlobalVariable *EfficiencySanitizer::createCacheFragInfoGV( - Module &M, const DataLayout &DL, Constant *UnitName) { - assert(Options.ToolType == EfficiencySanitizerOptions::ESAN_CacheFrag); - - auto *Int8PtrTy = Type::getInt8PtrTy(*Ctx); - auto *Int8PtrPtrTy = Int8PtrTy->getPointerTo(); - auto *Int32Ty = Type::getInt32Ty(*Ctx); - auto *Int32PtrTy = Type::getInt32PtrTy(*Ctx); - auto *Int64Ty = Type::getInt64Ty(*Ctx); - auto *Int64PtrTy = Type::getInt64PtrTy(*Ctx); - // This structure should be kept consistent with the StructInfo struct - // in the runtime library. - // struct StructInfo { - // const char *StructName; - // u32 Size; - // u32 NumFields; - // u32 *FieldOffset; // auxiliary struct field info. - // u32 *FieldSize; // auxiliary struct field info. - // const char **FieldTypeName; // auxiliary struct field info. - // u64 *FieldCounters; - // u64 *ArrayCounter; - // }; - auto *StructInfoTy = - StructType::get(Int8PtrTy, Int32Ty, Int32Ty, Int32PtrTy, Int32PtrTy, - Int8PtrPtrTy, Int64PtrTy, Int64PtrTy); - auto *StructInfoPtrTy = StructInfoTy->getPointerTo(); - // This structure should be kept consistent with the CacheFragInfo struct - // in the runtime library. - // struct CacheFragInfo { - // const char *UnitName; - // u32 NumStructs; - // StructInfo *Structs; - // }; - auto *CacheFragInfoTy = StructType::get(Int8PtrTy, Int32Ty, StructInfoPtrTy); - - std::vector<StructType *> Vec = M.getIdentifiedStructTypes(); - unsigned NumStructs = 0; - SmallVector<Constant *, 16> Initializers; - - for (auto &StructTy : Vec) { - if (shouldIgnoreStructType(StructTy)) { - ++NumIgnoredStructs; - continue; - } - ++NumStructs; - - // StructName. - SmallString<MaxStructCounterNameSize> CounterNameStr; - createStructCounterName(StructTy, CounterNameStr); - GlobalVariable *StructCounterName = createPrivateGlobalForString( - M, CounterNameStr, /*AllowMerging*/true); - - // Counters. - // We create the counter array with StructCounterName and weak linkage - // so that the structs with the same name and layout from different - // compilation units will be merged into one. 
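For orientation, the counter array allocated below holds one slot per struct field plus one trailing slot for whole-struct (array-element) accesses, which is what getFieldCounterIdx, getArrayCounterIdx and getStructCounterSize earlier in this deleted file encode. The same layout in plain index terms (sketch only):

// A struct with N fields gets N + 1 counters:
//   index f, 0 <= f < N : accesses to field f
//   index N             : accesses to the struct as an array element
static unsigned fieldCounterIndex(unsigned FieldNo)   { return FieldNo; }
static unsigned arrayCounterIndex(unsigned NumFields) { return NumFields; }
static unsigned counterArraySize(unsigned NumFields)  { return NumFields + 1; }
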
- auto *CounterArrayTy = ArrayType::get(Int64Ty, - getStructCounterSize(StructTy)); - GlobalVariable *Counters = - new GlobalVariable(M, CounterArrayTy, false, - GlobalVariable::WeakAnyLinkage, - ConstantAggregateZero::get(CounterArrayTy), - CounterNameStr); - - // Remember the counter variable for each struct type. - StructTyMap.insert(std::pair<Type *, GlobalVariable *>(StructTy, Counters)); - - // We pass the field type name array, offset array, and size array to - // the runtime for better reporting. - GlobalVariable *TypeName = nullptr, *Offset = nullptr, *Size = nullptr; - if (ClAuxFieldInfo) - createCacheFragAuxGV(M, DL, StructTy, TypeName, Offset, Size); - - Constant *FieldCounterIdx[2]; - FieldCounterIdx[0] = ConstantInt::get(Int32Ty, 0); - FieldCounterIdx[1] = ConstantInt::get(Int32Ty, - getFieldCounterIdx(StructTy)); - Constant *ArrayCounterIdx[2]; - ArrayCounterIdx[0] = ConstantInt::get(Int32Ty, 0); - ArrayCounterIdx[1] = ConstantInt::get(Int32Ty, - getArrayCounterIdx(StructTy)); - Initializers.push_back(ConstantStruct::get( - StructInfoTy, - ConstantExpr::getPointerCast(StructCounterName, Int8PtrTy), - ConstantInt::get(Int32Ty, - DL.getStructLayout(StructTy)->getSizeInBytes()), - ConstantInt::get(Int32Ty, StructTy->getNumElements()), - Offset == nullptr ? ConstantPointerNull::get(Int32PtrTy) - : ConstantExpr::getPointerCast(Offset, Int32PtrTy), - Size == nullptr ? ConstantPointerNull::get(Int32PtrTy) - : ConstantExpr::getPointerCast(Size, Int32PtrTy), - TypeName == nullptr - ? ConstantPointerNull::get(Int8PtrPtrTy) - : ConstantExpr::getPointerCast(TypeName, Int8PtrPtrTy), - ConstantExpr::getGetElementPtr(CounterArrayTy, Counters, - FieldCounterIdx), - ConstantExpr::getGetElementPtr(CounterArrayTy, Counters, - ArrayCounterIdx))); - } - // Structs. - Constant *StructInfo; - if (NumStructs == 0) { - StructInfo = ConstantPointerNull::get(StructInfoPtrTy); - } else { - auto *StructInfoArrayTy = ArrayType::get(StructInfoTy, NumStructs); - StructInfo = ConstantExpr::getPointerCast( - new GlobalVariable(M, StructInfoArrayTy, false, - GlobalVariable::InternalLinkage, - ConstantArray::get(StructInfoArrayTy, Initializers)), - StructInfoPtrTy); - } - - auto *CacheFragInfoGV = new GlobalVariable( - M, CacheFragInfoTy, true, GlobalVariable::InternalLinkage, - ConstantStruct::get(CacheFragInfoTy, UnitName, - ConstantInt::get(Int32Ty, NumStructs), StructInfo)); - return CacheFragInfoGV; -} - -// Create the tool-specific argument passed to EsanInit and EsanExit. -Constant *EfficiencySanitizer::createEsanInitToolInfoArg(Module &M, - const DataLayout &DL) { - // This structure contains tool-specific information about each compilation - // unit (module) and is passed to the runtime library. - GlobalVariable *ToolInfoGV = nullptr; - - auto *Int8PtrTy = Type::getInt8PtrTy(*Ctx); - // Compilation unit name. - auto *UnitName = ConstantExpr::getPointerCast( - createPrivateGlobalForString(M, M.getModuleIdentifier(), true), - Int8PtrTy); - - // Create the tool-specific variable. - if (Options.ToolType == EfficiencySanitizerOptions::ESAN_CacheFrag) - ToolInfoGV = createCacheFragInfoGV(M, DL, UnitName); - - if (ToolInfoGV != nullptr) - return ConstantExpr::getPointerCast(ToolInfoGV, Int8PtrTy); - - // Create the null pointer if no tool-specific variable created. 
- return ConstantPointerNull::get(Int8PtrTy); -} - -void EfficiencySanitizer::createDestructor(Module &M, Constant *ToolInfoArg) { - PointerType *Int8PtrTy = Type::getInt8PtrTy(*Ctx); - EsanDtorFunction = Function::Create(FunctionType::get(Type::getVoidTy(*Ctx), - false), - GlobalValue::InternalLinkage, - EsanModuleDtorName, &M); - ReturnInst::Create(*Ctx, BasicBlock::Create(*Ctx, "", EsanDtorFunction)); - IRBuilder<> IRB_Dtor(EsanDtorFunction->getEntryBlock().getTerminator()); - Function *EsanExit = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(EsanExitName, IRB_Dtor.getVoidTy(), - Int8PtrTy)); - EsanExit->setLinkage(Function::ExternalLinkage); - IRB_Dtor.CreateCall(EsanExit, {ToolInfoArg}); - appendToGlobalDtors(M, EsanDtorFunction, EsanCtorAndDtorPriority); -} - -bool EfficiencySanitizer::initOnModule(Module &M) { - - Triple TargetTriple(M.getTargetTriple()); - if (TargetTriple.isMIPS64()) - ShadowParams = ShadowParams40; - else - ShadowParams = ShadowParams47; - - Ctx = &M.getContext(); - const DataLayout &DL = M.getDataLayout(); - IRBuilder<> IRB(M.getContext()); - IntegerType *OrdTy = IRB.getInt32Ty(); - PointerType *Int8PtrTy = Type::getInt8PtrTy(*Ctx); - IntptrTy = DL.getIntPtrType(M.getContext()); - // Create the variable passed to EsanInit and EsanExit. - Constant *ToolInfoArg = createEsanInitToolInfoArg(M, DL); - // Constructor - // We specify the tool type both in the EsanWhichToolName global - // and as an arg to the init routine as a sanity check. - std::tie(EsanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions( - M, EsanModuleCtorName, EsanInitName, /*InitArgTypes=*/{OrdTy, Int8PtrTy}, - /*InitArgs=*/{ - ConstantInt::get(OrdTy, static_cast<int>(Options.ToolType)), - ToolInfoArg}); - appendToGlobalCtors(M, EsanCtorFunction, EsanCtorAndDtorPriority); - - createDestructor(M, ToolInfoArg); - - new GlobalVariable(M, OrdTy, true, - GlobalValue::WeakAnyLinkage, - ConstantInt::get(OrdTy, - static_cast<int>(Options.ToolType)), - EsanWhichToolName); - - return true; -} - -Value *EfficiencySanitizer::appToShadow(Value *Shadow, IRBuilder<> &IRB) { - // Shadow = ((App & Mask) + Offs) >> Scale - Shadow = IRB.CreateAnd(Shadow, ConstantInt::get(IntptrTy, ShadowParams.ShadowMask)); - uint64_t Offs; - int Scale = ShadowScale[Options.ToolType]; - if (Scale <= 2) - Offs = ShadowParams.ShadowOffs[Scale]; - else - Offs = ShadowParams.ShadowOffs[0] << Scale; - Shadow = IRB.CreateAdd(Shadow, ConstantInt::get(IntptrTy, Offs)); - if (Scale > 0) - Shadow = IRB.CreateLShr(Shadow, Scale); - return Shadow; -} - -bool EfficiencySanitizer::shouldIgnoreMemoryAccess(Instruction *I) { - if (Options.ToolType == EfficiencySanitizerOptions::ESAN_CacheFrag) { - // We'd like to know about cache fragmentation in vtable accesses and - // constant data references, so we do not currently ignore anything. - return false; - } else if (Options.ToolType == EfficiencySanitizerOptions::ESAN_WorkingSet) { - // TODO: the instrumentation disturbs the data layout on the stack, so we - // may want to add an option to ignore stack references (if we can - // distinguish them) to reduce overhead. - } - // TODO(bruening): future tools will be returning true for some cases. 
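The appToShadow helper above builds IR for the mapping Shadow = ((App & Mask) + Offs) >> Scale. The same computation on plain integers, instantiated with the 47-bit ShadowParams quoted earlier (a reference sketch, not emitted code; Scale is 2 for the cache-frag tool and 6 for the working-set tool):

#include <cstdint>

static uint64_t appToShadow47(uint64_t App, int Scale) {
  const uint64_t Mask = 0x00000fffffffffffull;
  const uint64_t Offs[3] = {0x0000130000000000ull, 0x0000220000000000ull,
                            0x0000440000000000ull};
  uint64_t Off = (Scale <= 2) ? Offs[Scale] : Offs[0] << Scale;
  uint64_t Shadow = (App & Mask) + Off;
  return Scale > 0 ? Shadow >> Scale : Shadow;
}
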
- return false; -} - -bool EfficiencySanitizer::runOnModule(Module &M) { - bool Res = initOnModule(M); - initializeCallbacks(M); - for (auto &F : M) { - Res |= runOnFunction(F, M); - } - return Res; -} - -bool EfficiencySanitizer::runOnFunction(Function &F, Module &M) { - // This is required to prevent instrumenting the call to __esan_init from - // within the module constructor. - if (&F == EsanCtorFunction) - return false; - SmallVector<Instruction *, 8> LoadsAndStores; - SmallVector<Instruction *, 8> MemIntrinCalls; - SmallVector<Instruction *, 8> GetElementPtrs; - bool Res = false; - const DataLayout &DL = M.getDataLayout(); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - - for (auto &BB : F) { - for (auto &Inst : BB) { - if ((isa<LoadInst>(Inst) || isa<StoreInst>(Inst) || - isa<AtomicRMWInst>(Inst) || isa<AtomicCmpXchgInst>(Inst)) && - !shouldIgnoreMemoryAccess(&Inst)) - LoadsAndStores.push_back(&Inst); - else if (isa<MemIntrinsic>(Inst)) - MemIntrinCalls.push_back(&Inst); - else if (isa<GetElementPtrInst>(Inst)) - GetElementPtrs.push_back(&Inst); - else if (CallInst *CI = dyn_cast<CallInst>(&Inst)) - maybeMarkSanitizerLibraryCallNoBuiltin(CI, TLI); - } - } - - if (ClInstrumentLoadsAndStores) { - for (auto Inst : LoadsAndStores) { - Res |= instrumentLoadOrStore(Inst, DL); - } - } - - if (ClInstrumentMemIntrinsics) { - for (auto Inst : MemIntrinCalls) { - Res |= instrumentMemIntrinsic(cast<MemIntrinsic>(Inst)); - } - } - - if (Options.ToolType == EfficiencySanitizerOptions::ESAN_CacheFrag) { - for (auto Inst : GetElementPtrs) { - Res |= instrumentGetElementPtr(Inst, M); - } - } - - return Res; -} - -bool EfficiencySanitizer::instrumentLoadOrStore(Instruction *I, - const DataLayout &DL) { - IRBuilder<> IRB(I); - bool IsStore; - Value *Addr; - unsigned Alignment; - if (LoadInst *Load = dyn_cast<LoadInst>(I)) { - IsStore = false; - Alignment = Load->getAlignment(); - Addr = Load->getPointerOperand(); - } else if (StoreInst *Store = dyn_cast<StoreInst>(I)) { - IsStore = true; - Alignment = Store->getAlignment(); - Addr = Store->getPointerOperand(); - } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(I)) { - IsStore = true; - Alignment = 0; - Addr = RMW->getPointerOperand(); - } else if (AtomicCmpXchgInst *Xchg = dyn_cast<AtomicCmpXchgInst>(I)) { - IsStore = true; - Alignment = 0; - Addr = Xchg->getPointerOperand(); - } else - llvm_unreachable("Unsupported mem access type"); - - Type *OrigTy = cast<PointerType>(Addr->getType())->getElementType(); - const uint32_t TypeSizeBytes = DL.getTypeStoreSizeInBits(OrigTy) / 8; - Value *OnAccessFunc = nullptr; - - // Convert 0 to the default alignment. - if (Alignment == 0) - Alignment = DL.getPrefTypeAlignment(OrigTy); - - if (IsStore) - NumInstrumentedStores++; - else - NumInstrumentedLoads++; - int Idx = getMemoryAccessFuncIndex(Addr, DL); - if (Idx < 0) { - OnAccessFunc = IsStore ? EsanUnalignedStoreN : EsanUnalignedLoadN; - IRB.CreateCall(OnAccessFunc, - {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()), - ConstantInt::get(IntptrTy, TypeSizeBytes)}); - } else { - if (ClInstrumentFastpath && - instrumentFastpath(I, DL, IsStore, Addr, Alignment)) { - NumFastpaths++; - return true; - } - if (Alignment == 0 || (Alignment % TypeSizeBytes) == 0) - OnAccessFunc = IsStore ? EsanAlignedStore[Idx] : EsanAlignedLoad[Idx]; - else - OnAccessFunc = IsStore ? 
EsanUnalignedStore[Idx] : EsanUnalignedLoad[Idx]; - IRB.CreateCall(OnAccessFunc, - IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy())); - } - return true; -} - -// It's simplest to replace the memset/memmove/memcpy intrinsics with -// calls that the runtime library intercepts. -// Our pass is late enough that calls should not turn back into intrinsics. -bool EfficiencySanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { - IRBuilder<> IRB(MI); - bool Res = false; - if (isa<MemSetInst>(MI)) { - IRB.CreateCall( - MemsetFn, - {IRB.CreatePointerCast(MI->getArgOperand(0), IRB.getInt8PtrTy()), - IRB.CreateIntCast(MI->getArgOperand(1), IRB.getInt32Ty(), false), - IRB.CreateIntCast(MI->getArgOperand(2), IntptrTy, false)}); - MI->eraseFromParent(); - Res = true; - } else if (isa<MemTransferInst>(MI)) { - IRB.CreateCall( - isa<MemCpyInst>(MI) ? MemcpyFn : MemmoveFn, - {IRB.CreatePointerCast(MI->getArgOperand(0), IRB.getInt8PtrTy()), - IRB.CreatePointerCast(MI->getArgOperand(1), IRB.getInt8PtrTy()), - IRB.CreateIntCast(MI->getArgOperand(2), IntptrTy, false)}); - MI->eraseFromParent(); - Res = true; - } else - llvm_unreachable("Unsupported mem intrinsic type"); - return Res; -} - -bool EfficiencySanitizer::instrumentGetElementPtr(Instruction *I, Module &M) { - GetElementPtrInst *GepInst = dyn_cast<GetElementPtrInst>(I); - bool Res = false; - if (GepInst == nullptr || GepInst->getNumIndices() == 1) { - ++NumIgnoredGEPs; - return false; - } - Type *SourceTy = GepInst->getSourceElementType(); - StructType *StructTy = nullptr; - ConstantInt *Idx; - // Check if GEP calculates address from a struct array. - if (isa<StructType>(SourceTy)) { - StructTy = cast<StructType>(SourceTy); - Idx = dyn_cast<ConstantInt>(GepInst->getOperand(1)); - if ((Idx == nullptr || Idx->getSExtValue() != 0) && - !shouldIgnoreStructType(StructTy) && StructTyMap.count(StructTy) != 0) - Res |= insertCounterUpdate(I, StructTy, getArrayCounterIdx(StructTy)); - } - // Iterate all (except the first and the last) idx within each GEP instruction - // for possible nested struct field address calculation. - for (unsigned i = 1; i < GepInst->getNumIndices(); ++i) { - SmallVector<Value *, 8> IdxVec(GepInst->idx_begin(), - GepInst->idx_begin() + i); - Type *Ty = GetElementPtrInst::getIndexedType(SourceTy, IdxVec); - unsigned CounterIdx = 0; - if (isa<ArrayType>(Ty)) { - ArrayType *ArrayTy = cast<ArrayType>(Ty); - StructTy = dyn_cast<StructType>(ArrayTy->getElementType()); - if (shouldIgnoreStructType(StructTy) || StructTyMap.count(StructTy) == 0) - continue; - // The last counter for struct array access. - CounterIdx = getArrayCounterIdx(StructTy); - } else if (isa<StructType>(Ty)) { - StructTy = cast<StructType>(Ty); - if (shouldIgnoreStructType(StructTy) || StructTyMap.count(StructTy) == 0) - continue; - // Get the StructTy's subfield index. 
- Idx = cast<ConstantInt>(GepInst->getOperand(i+1)); - assert(Idx->getSExtValue() >= 0 && - Idx->getSExtValue() < StructTy->getNumElements()); - CounterIdx = getFieldCounterIdx(StructTy) + Idx->getSExtValue(); - } - Res |= insertCounterUpdate(I, StructTy, CounterIdx); - } - if (Res) - ++NumInstrumentedGEPs; - else - ++NumIgnoredGEPs; - return Res; -} - -bool EfficiencySanitizer::insertCounterUpdate(Instruction *I, - StructType *StructTy, - unsigned CounterIdx) { - GlobalVariable *CounterArray = StructTyMap[StructTy]; - if (CounterArray == nullptr) - return false; - IRBuilder<> IRB(I); - Constant *Indices[2]; - // Xref http://llvm.org/docs/LangRef.html#i-getelementptr and - // http://llvm.org/docs/GetElementPtr.html. - // The first index of the GEP instruction steps through the first operand, - // i.e., the array itself. - Indices[0] = ConstantInt::get(IRB.getInt32Ty(), 0); - // The second index is the index within the array. - Indices[1] = ConstantInt::get(IRB.getInt32Ty(), CounterIdx); - Constant *Counter = - ConstantExpr::getGetElementPtr( - ArrayType::get(IRB.getInt64Ty(), getStructCounterSize(StructTy)), - CounterArray, Indices); - Value *Load = IRB.CreateLoad(Counter); - IRB.CreateStore(IRB.CreateAdd(Load, ConstantInt::get(IRB.getInt64Ty(), 1)), - Counter); - return true; -} - -int EfficiencySanitizer::getMemoryAccessFuncIndex(Value *Addr, - const DataLayout &DL) { - Type *OrigPtrTy = Addr->getType(); - Type *OrigTy = cast<PointerType>(OrigPtrTy)->getElementType(); - assert(OrigTy->isSized()); - // The size is always a multiple of 8. - uint32_t TypeSizeBytes = DL.getTypeStoreSizeInBits(OrigTy) / 8; - if (TypeSizeBytes != 1 && TypeSizeBytes != 2 && TypeSizeBytes != 4 && - TypeSizeBytes != 8 && TypeSizeBytes != 16) { - // Irregular sizes do not have per-size call targets. - NumAccessesWithIrregularSize++; - return -1; - } - size_t Idx = countTrailingZeros(TypeSizeBytes); - assert(Idx < NumberOfAccessSizes); - return Idx; -} - -bool EfficiencySanitizer::instrumentFastpath(Instruction *I, - const DataLayout &DL, bool IsStore, - Value *Addr, unsigned Alignment) { - if (Options.ToolType == EfficiencySanitizerOptions::ESAN_CacheFrag) { - return instrumentFastpathCacheFrag(I, DL, Addr, Alignment); - } else if (Options.ToolType == EfficiencySanitizerOptions::ESAN_WorkingSet) { - return instrumentFastpathWorkingSet(I, DL, Addr, Alignment); - } - return false; -} - -bool EfficiencySanitizer::instrumentFastpathCacheFrag(Instruction *I, - const DataLayout &DL, - Value *Addr, - unsigned Alignment) { - // Do nothing. - return true; // Return true to avoid slowpath instrumentation. -} - -bool EfficiencySanitizer::instrumentFastpathWorkingSet( - Instruction *I, const DataLayout &DL, Value *Addr, unsigned Alignment) { - assert(ShadowScale[Options.ToolType] == 6); // The code below assumes this - IRBuilder<> IRB(I); - Type *OrigTy = cast<PointerType>(Addr->getType())->getElementType(); - const uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy); - // Bail to the slowpath if the access might touch multiple cache lines. - // An access aligned to its size is guaranteed to be intra-cache-line. - // getMemoryAccessFuncIndex has already ruled out a size larger than 16 - // and thus larger than a cache line for platforms this tool targets - // (and our shadow memory setup assumes 64-byte cache lines). 
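The comment above, and the alignment test that follows, rest on a simple property: an access whose address is aligned to its own power-of-two size, with the size dividing 64, cannot straddle a 64-byte cache line. A tiny statement of that property (illustrative):

#include <cstdint>

// True iff [Addr, Addr + Size) stays within one 64-byte line.
static bool withinOneCacheLine(uint64_t Addr, uint64_t Size) {
  return (Addr / 64) == ((Addr + Size - 1) / 64);
}
// If Addr % Size == 0 and Size divides 64, then Addr % 64 is a multiple of
// Size, so (Addr % 64) + Size <= 64 and the access fits inside its line.
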
- assert(TypeSize <= 128); - if (!(TypeSize == 8 || - (Alignment % (TypeSize / 8)) == 0)) { - if (ClAssumeIntraCacheLine) - ++NumAssumedIntraCacheLine; - else - return false; - } - - // We inline instrumentation to set the corresponding shadow bits for - // each cache line touched by the application. Here we handle a single - // load or store where we've already ruled out the possibility that it - // might touch more than one cache line and thus we simply update the - // shadow memory for a single cache line. - // Our shadow memory model is fine with races when manipulating shadow values. - // We generate the following code: - // - // const char BitMask = 0x81; - // char *ShadowAddr = appToShadow(AppAddr); - // if ((*ShadowAddr & BitMask) != BitMask) - // *ShadowAddr |= Bitmask; - // - Value *AddrPtr = IRB.CreatePointerCast(Addr, IntptrTy); - Value *ShadowPtr = appToShadow(AddrPtr, IRB); - Type *ShadowTy = IntegerType::get(*Ctx, 8U); - Type *ShadowPtrTy = PointerType::get(ShadowTy, 0); - // The bottom bit is used for the current sampling period's working set. - // The top bit is used for the total working set. We set both on each - // memory access, if they are not already set. - Value *ValueMask = ConstantInt::get(ShadowTy, 0x81); // 10000001B - - Value *OldValue = IRB.CreateLoad(IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy)); - // The AND and CMP will be turned into a TEST instruction by the compiler. - Value *Cmp = IRB.CreateICmpNE(IRB.CreateAnd(OldValue, ValueMask), ValueMask); - Instruction *CmpTerm = SplitBlockAndInsertIfThen(Cmp, I, false); - // FIXME: do I need to call SetCurrentDebugLocation? - IRB.SetInsertPoint(CmpTerm); - // We use OR to set the shadow bits to avoid corrupting the middle 6 bits, - // which are used by the runtime library. - Value *NewVal = IRB.CreateOr(OldValue, ValueMask); - IRB.CreateStore(NewVal, IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy)); - IRB.SetInsertPoint(I); - - return true; -} diff --git a/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/lib/Transforms/Instrumentation/GCOVProfiling.cpp index 9af64ed332cd..59950ffc4e9a 100644 --- a/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -1,9 +1,8 @@ //===- GCOVProfiling.cpp - Insert edge counters for gcov profiling --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -103,11 +102,11 @@ private: std::vector<Regex> &Regexes); // Get pointers to the functions in the runtime library. - Constant *getStartFileFunc(); - Constant *getEmitFunctionFunc(); - Constant *getEmitArcsFunc(); - Constant *getSummaryInfoFunc(); - Constant *getEndFileFunc(); + FunctionCallee getStartFileFunc(); + FunctionCallee getEmitFunctionFunc(); + FunctionCallee getEmitArcsFunc(); + FunctionCallee getSummaryInfoFunc(); + FunctionCallee getEndFileFunc(); // Add the function to write out all our counters to the global destructor // list. 
@@ -648,7 +647,7 @@ void GCOVProfiler::AddFlushBeforeForkAndExec() { for (auto I : ForkAndExecs) { IRBuilder<> Builder(I); FunctionType *FTy = FunctionType::get(Builder.getVoidTy(), {}, false); - Constant *GCOVFlush = M->getOrInsertFunction("__gcov_flush", FTy); + FunctionCallee GCOVFlush = M->getOrInsertFunction("__gcov_flush", FTy); Builder.CreateCall(GCOVFlush); I->getParent()->splitBasicBlock(I); } @@ -811,14 +810,14 @@ bool GCOVProfiler::emitProfileArcs() { auto It = EdgeToCounter.find({Pred, &BB}); assert(It != EdgeToCounter.end()); const unsigned Edge = It->second; - Value *EdgeCounter = - BuilderForPhi.CreateConstInBoundsGEP2_64(Counters, 0, Edge); + Value *EdgeCounter = BuilderForPhi.CreateConstInBoundsGEP2_64( + Counters->getValueType(), Counters, 0, Edge); Phi->addIncoming(EdgeCounter, Pred); } // Skip phis, landingpads. IRBuilder<> Builder(&*BB.getFirstInsertionPt()); - Value *Count = Builder.CreateLoad(Phi); + Value *Count = Builder.CreateLoad(Builder.getInt64Ty(), Phi); Count = Builder.CreateAdd(Count, Builder.getInt64(1)); Builder.CreateStore(Count, Phi); @@ -827,9 +826,9 @@ bool GCOVProfiler::emitProfileArcs() { auto It = EdgeToCounter.find({&BB, nullptr}); assert(It != EdgeToCounter.end()); const unsigned Edge = It->second; - Value *Counter = - Builder.CreateConstInBoundsGEP2_64(Counters, 0, Edge); - Value *Count = Builder.CreateLoad(Counter); + Value *Counter = Builder.CreateConstInBoundsGEP2_64( + Counters->getValueType(), Counters, 0, Edge); + Value *Count = Builder.CreateLoad(Builder.getInt64Ty(), Counter); Count = Builder.CreateAdd(Count, Builder.getInt64(1)); Builder.CreateStore(Count, Counter); } @@ -864,7 +863,7 @@ bool GCOVProfiler::emitProfileArcs() { // Initialize the environment and register the local writeout and flush // functions. 
- Constant *GCOVInit = M->getOrInsertFunction("llvm_gcov_init", FTy); + FunctionCallee GCOVInit = M->getOrInsertFunction("llvm_gcov_init", FTy); Builder.CreateCall(GCOVInit, {WriteoutF, FlushF}); Builder.CreateRetVoid(); @@ -874,22 +873,21 @@ bool GCOVProfiler::emitProfileArcs() { return Result; } -Constant *GCOVProfiler::getStartFileFunc() { +FunctionCallee GCOVProfiler::getStartFileFunc() { Type *Args[] = { Type::getInt8PtrTy(*Ctx), // const char *orig_filename Type::getInt8PtrTy(*Ctx), // const char version[4] Type::getInt32Ty(*Ctx), // uint32_t checksum }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); - auto *Res = M->getOrInsertFunction("llvm_gcda_start_file", FTy); - if (Function *FunRes = dyn_cast<Function>(Res)) - if (auto AK = TLI->getExtAttrForI32Param(false)) - FunRes->addParamAttr(2, AK); + AttributeList AL; + if (auto AK = TLI->getExtAttrForI32Param(false)) + AL = AL.addParamAttribute(*Ctx, 2, AK); + FunctionCallee Res = M->getOrInsertFunction("llvm_gcda_start_file", FTy, AL); return Res; - } -Constant *GCOVProfiler::getEmitFunctionFunc() { +FunctionCallee GCOVProfiler::getEmitFunctionFunc() { Type *Args[] = { Type::getInt32Ty(*Ctx), // uint32_t ident Type::getInt8PtrTy(*Ctx), // const char *function_name @@ -898,36 +896,34 @@ Constant *GCOVProfiler::getEmitFunctionFunc() { Type::getInt32Ty(*Ctx), // uint32_t cfg_checksum }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); - auto *Res = M->getOrInsertFunction("llvm_gcda_emit_function", FTy); - if (Function *FunRes = dyn_cast<Function>(Res)) - if (auto AK = TLI->getExtAttrForI32Param(false)) { - FunRes->addParamAttr(0, AK); - FunRes->addParamAttr(2, AK); - FunRes->addParamAttr(3, AK); - FunRes->addParamAttr(4, AK); - } - return Res; + AttributeList AL; + if (auto AK = TLI->getExtAttrForI32Param(false)) { + AL = AL.addParamAttribute(*Ctx, 0, AK); + AL = AL.addParamAttribute(*Ctx, 2, AK); + AL = AL.addParamAttribute(*Ctx, 3, AK); + AL = AL.addParamAttribute(*Ctx, 4, AK); + } + return M->getOrInsertFunction("llvm_gcda_emit_function", FTy); } -Constant *GCOVProfiler::getEmitArcsFunc() { +FunctionCallee GCOVProfiler::getEmitArcsFunc() { Type *Args[] = { Type::getInt32Ty(*Ctx), // uint32_t num_counters Type::getInt64PtrTy(*Ctx), // uint64_t *counters }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); - auto *Res = M->getOrInsertFunction("llvm_gcda_emit_arcs", FTy); - if (Function *FunRes = dyn_cast<Function>(Res)) - if (auto AK = TLI->getExtAttrForI32Param(false)) - FunRes->addParamAttr(0, AK); - return Res; + AttributeList AL; + if (auto AK = TLI->getExtAttrForI32Param(false)) + AL = AL.addParamAttribute(*Ctx, 0, AK); + return M->getOrInsertFunction("llvm_gcda_emit_arcs", FTy, AL); } -Constant *GCOVProfiler::getSummaryInfoFunc() { +FunctionCallee GCOVProfiler::getSummaryInfoFunc() { FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), false); return M->getOrInsertFunction("llvm_gcda_summary_info", FTy); } -Constant *GCOVProfiler::getEndFileFunc() { +FunctionCallee GCOVProfiler::getEndFileFunc() { FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), false); return M->getOrInsertFunction("llvm_gcda_end_file", FTy); } @@ -947,11 +943,11 @@ Function *GCOVProfiler::insertCounterWriteout( BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", WriteoutF); IRBuilder<> Builder(BB); - Constant *StartFile = getStartFileFunc(); - Constant *EmitFunction = getEmitFunctionFunc(); - Constant *EmitArcs = getEmitArcsFunc(); - Constant *SummaryInfo = 
getSummaryInfoFunc(); - Constant *EndFile = getEndFileFunc(); + FunctionCallee StartFile = getStartFileFunc(); + FunctionCallee EmitFunction = getEmitFunctionFunc(); + FunctionCallee EmitArcs = getEmitArcsFunc(); + FunctionCallee SummaryInfo = getSummaryInfoFunc(); + FunctionCallee EndFile = getEndFileFunc(); NamedMDNode *CUNodes = M->getNamedMetadata("llvm.dbg.cu"); if (!CUNodes) { @@ -1088,22 +1084,32 @@ Function *GCOVProfiler::insertCounterWriteout( PHINode *IV = Builder.CreatePHI(Builder.getInt32Ty(), /*NumReservedValues*/ 2); IV->addIncoming(Builder.getInt32(0), BB); - auto *FileInfoPtr = - Builder.CreateInBoundsGEP(FileInfoArrayGV, {Builder.getInt32(0), IV}); - auto *StartFileCallArgsPtr = Builder.CreateStructGEP(FileInfoPtr, 0); + auto *FileInfoPtr = Builder.CreateInBoundsGEP( + FileInfoArrayTy, FileInfoArrayGV, {Builder.getInt32(0), IV}); + auto *StartFileCallArgsPtr = + Builder.CreateStructGEP(FileInfoTy, FileInfoPtr, 0); auto *StartFileCall = Builder.CreateCall( StartFile, - {Builder.CreateLoad(Builder.CreateStructGEP(StartFileCallArgsPtr, 0)), - Builder.CreateLoad(Builder.CreateStructGEP(StartFileCallArgsPtr, 1)), - Builder.CreateLoad(Builder.CreateStructGEP(StartFileCallArgsPtr, 2))}); + {Builder.CreateLoad(StartFileCallArgsTy->getElementType(0), + Builder.CreateStructGEP(StartFileCallArgsTy, + StartFileCallArgsPtr, 0)), + Builder.CreateLoad(StartFileCallArgsTy->getElementType(1), + Builder.CreateStructGEP(StartFileCallArgsTy, + StartFileCallArgsPtr, 1)), + Builder.CreateLoad(StartFileCallArgsTy->getElementType(2), + Builder.CreateStructGEP(StartFileCallArgsTy, + StartFileCallArgsPtr, 2))}); if (auto AK = TLI->getExtAttrForI32Param(false)) StartFileCall->addParamAttr(2, AK); auto *NumCounters = - Builder.CreateLoad(Builder.CreateStructGEP(FileInfoPtr, 1)); + Builder.CreateLoad(FileInfoTy->getElementType(1), + Builder.CreateStructGEP(FileInfoTy, FileInfoPtr, 1)); auto *EmitFunctionCallArgsArray = - Builder.CreateLoad(Builder.CreateStructGEP(FileInfoPtr, 2)); + Builder.CreateLoad(FileInfoTy->getElementType(2), + Builder.CreateStructGEP(FileInfoTy, FileInfoPtr, 2)); auto *EmitArcsCallArgsArray = - Builder.CreateLoad(Builder.CreateStructGEP(FileInfoPtr, 3)); + Builder.CreateLoad(FileInfoTy->getElementType(3), + Builder.CreateStructGEP(FileInfoTy, FileInfoPtr, 3)); auto *EnterCounterLoopCond = Builder.CreateICmpSLT(Builder.getInt32(0), NumCounters); Builder.CreateCondBr(EnterCounterLoopCond, CounterLoopHeader, FileLoopLatch); @@ -1111,16 +1117,26 @@ Function *GCOVProfiler::insertCounterWriteout( Builder.SetInsertPoint(CounterLoopHeader); auto *JV = Builder.CreatePHI(Builder.getInt32Ty(), /*NumReservedValues*/ 2); JV->addIncoming(Builder.getInt32(0), FileLoopHeader); - auto *EmitFunctionCallArgsPtr = - Builder.CreateInBoundsGEP(EmitFunctionCallArgsArray, {JV}); + auto *EmitFunctionCallArgsPtr = Builder.CreateInBoundsGEP( + EmitFunctionCallArgsTy, EmitFunctionCallArgsArray, JV); auto *EmitFunctionCall = Builder.CreateCall( EmitFunction, - {Builder.CreateLoad(Builder.CreateStructGEP(EmitFunctionCallArgsPtr, 0)), - Builder.CreateLoad(Builder.CreateStructGEP(EmitFunctionCallArgsPtr, 1)), - Builder.CreateLoad(Builder.CreateStructGEP(EmitFunctionCallArgsPtr, 2)), - Builder.CreateLoad(Builder.CreateStructGEP(EmitFunctionCallArgsPtr, 3)), - Builder.CreateLoad( - Builder.CreateStructGEP(EmitFunctionCallArgsPtr, 4))}); + {Builder.CreateLoad(EmitFunctionCallArgsTy->getElementType(0), + Builder.CreateStructGEP(EmitFunctionCallArgsTy, + EmitFunctionCallArgsPtr, 0)), + 
Builder.CreateLoad(EmitFunctionCallArgsTy->getElementType(1), + Builder.CreateStructGEP(EmitFunctionCallArgsTy, + EmitFunctionCallArgsPtr, 1)), + Builder.CreateLoad(EmitFunctionCallArgsTy->getElementType(2), + Builder.CreateStructGEP(EmitFunctionCallArgsTy, + EmitFunctionCallArgsPtr, 2)), + Builder.CreateLoad(EmitFunctionCallArgsTy->getElementType(3), + Builder.CreateStructGEP(EmitFunctionCallArgsTy, + EmitFunctionCallArgsPtr, 3)), + Builder.CreateLoad(EmitFunctionCallArgsTy->getElementType(4), + Builder.CreateStructGEP(EmitFunctionCallArgsTy, + EmitFunctionCallArgsPtr, + 4))}); if (auto AK = TLI->getExtAttrForI32Param(false)) { EmitFunctionCall->addParamAttr(0, AK); EmitFunctionCall->addParamAttr(2, AK); @@ -1128,11 +1144,15 @@ Function *GCOVProfiler::insertCounterWriteout( EmitFunctionCall->addParamAttr(4, AK); } auto *EmitArcsCallArgsPtr = - Builder.CreateInBoundsGEP(EmitArcsCallArgsArray, {JV}); + Builder.CreateInBoundsGEP(EmitArcsCallArgsTy, EmitArcsCallArgsArray, JV); auto *EmitArcsCall = Builder.CreateCall( EmitArcs, - {Builder.CreateLoad(Builder.CreateStructGEP(EmitArcsCallArgsPtr, 0)), - Builder.CreateLoad(Builder.CreateStructGEP(EmitArcsCallArgsPtr, 1))}); + {Builder.CreateLoad( + EmitArcsCallArgsTy->getElementType(0), + Builder.CreateStructGEP(EmitArcsCallArgsTy, EmitArcsCallArgsPtr, 0)), + Builder.CreateLoad(EmitArcsCallArgsTy->getElementType(1), + Builder.CreateStructGEP(EmitArcsCallArgsTy, + EmitArcsCallArgsPtr, 1))}); if (auto AK = TLI->getExtAttrForI32Param(false)) EmitArcsCall->addParamAttr(0, AK); auto *NextJV = Builder.CreateAdd(JV, Builder.getInt32(1)); @@ -1172,7 +1192,7 @@ insertFlush(ArrayRef<std::pair<GlobalVariable*, MDNode*> > CountersBySP) { BasicBlock *Entry = BasicBlock::Create(*Ctx, "entry", FlushF); // Write out the current counters. - Constant *WriteoutF = M->getFunction("__llvm_gcov_writeout"); + Function *WriteoutF = M->getFunction("__llvm_gcov_writeout"); assert(WriteoutF && "Need to create the writeout function first!"); IRBuilder<> Builder(Entry); diff --git a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index d04c2b76288f..90a9f4955a4b 100644 --- a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -1,9 +1,8 @@ //===- HWAddressSanitizer.cpp - detector of uninitialized reads -------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -12,6 +11,7 @@ /// based on tagged addressing. //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Instrumentation/HWAddressSanitizer.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -21,6 +21,7 @@ #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -125,10 +126,10 @@ static cl::opt<bool> ClEnableKhwasan( // is accessed. 
The shadow mapping looks like: // Shadow = (Mem >> scale) + offset -static cl::opt<unsigned long long> ClMappingOffset( - "hwasan-mapping-offset", - cl::desc("HWASan shadow mapping offset [EXPERIMENTAL]"), cl::Hidden, - cl::init(0)); +static cl::opt<uint64_t> + ClMappingOffset("hwasan-mapping-offset", + cl::desc("HWASan shadow mapping offset [EXPERIMENTAL]"), + cl::Hidden, cl::init(0)); static cl::opt<bool> ClWithIfunc("hwasan-with-ifunc", @@ -148,42 +149,46 @@ static cl::opt<bool> "in a thread-local ring buffer"), cl::Hidden, cl::init(true)); static cl::opt<bool> - ClCreateFrameDescriptions("hwasan-create-frame-descriptions", - cl::desc("create static frame descriptions"), - cl::Hidden, cl::init(true)); - -static cl::opt<bool> ClInstrumentMemIntrinsics("hwasan-instrument-mem-intrinsics", cl::desc("instrument memory intrinsics"), cl::Hidden, cl::init(true)); + +static cl::opt<bool> + ClInstrumentLandingPads("hwasan-instrument-landing-pads", + cl::desc("instrument landing pads"), cl::Hidden, + cl::init(true)); + +static cl::opt<bool> ClInlineAllChecks("hwasan-inline-all-checks", + cl::desc("inline all checks"), + cl::Hidden, cl::init(false)); + namespace { /// An instrumentation pass implementing detection of addressability bugs /// using tagged pointers. -class HWAddressSanitizer : public FunctionPass { +class HWAddressSanitizer { public: - // Pass identification, replacement for typeid. - static char ID; - - explicit HWAddressSanitizer(bool CompileKernel = false, bool Recover = false) - : FunctionPass(ID) { + explicit HWAddressSanitizer(Module &M, bool CompileKernel = false, + bool Recover = false) { this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover; this->CompileKernel = ClEnableKhwasan.getNumOccurrences() > 0 ? ClEnableKhwasan : CompileKernel; - } - StringRef getPassName() const override { return "HWAddressSanitizer"; } + initializeModule(M); + } - bool runOnFunction(Function &F) override; - bool doInitialization(Module &M) override; + bool sanitizeFunction(Function &F); + void initializeModule(Module &M); void initializeCallbacks(Module &M); + Value *getDynamicShadowIfunc(IRBuilder<> &IRB); Value *getDynamicShadowNonTls(IRBuilder<> &IRB); void untagPointerOperand(Instruction *I, Value *Addr); - Value *memToShadow(Value *Shadow, Type *Ty, IRBuilder<> &IRB); - void instrumentMemAccessInline(Value *PtrLong, bool IsWrite, + Value *shadowBase(); + Value *memToShadow(Value *Shadow, IRBuilder<> &IRB); + void instrumentMemAccessInline(Value *Ptr, bool IsWrite, unsigned AccessSizeIndex, Instruction *InsertBefore); void instrumentMemIntrinsic(MemIntrinsic *MI); @@ -193,11 +198,15 @@ public: Value **MaybeMask); bool isInterestingAlloca(const AllocaInst &AI); - bool tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag); + bool tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size); Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag); Value *untagPointer(IRBuilder<> &IRB, Value *PtrLong); - bool instrumentStack(SmallVectorImpl<AllocaInst *> &Allocas, - SmallVectorImpl<Instruction *> &RetVec, Value *StackTag); + bool instrumentStack( + SmallVectorImpl<AllocaInst *> &Allocas, + DenseMap<AllocaInst *, std::vector<DbgDeclareInst *>> &AllocaDeclareMap, + SmallVectorImpl<Instruction *> &RetVec, Value *StackTag); + Value *readRegister(IRBuilder<> &IRB, StringRef Name); + bool instrumentLandingPads(SmallVectorImpl<Instruction *> &RetVec); Value *getNextTagWithCall(IRBuilder<> &IRB); Value *getStackBaseTag(IRBuilder<> &IRB); Value 
*getAllocaTag(IRBuilder<> &IRB, Value *StackTag, AllocaInst *AI, @@ -205,31 +214,14 @@ public: Value *getUARTag(IRBuilder<> &IRB, Value *StackTag); Value *getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty); - Value *emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord); + void emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord); private: LLVMContext *C; std::string CurModuleUniqueId; Triple TargetTriple; - Function *HWAsanMemmove, *HWAsanMemcpy, *HWAsanMemset; - - // Frame description is a way to pass names/sizes of local variables - // to the run-time w/o adding extra executable code in every function. - // We do this by creating a separate section with {PC,Descr} pairs and passing - // the section beg/end to __hwasan_init_frames() at module init time. - std::string createFrameString(ArrayRef<AllocaInst*> Allocas); - void createFrameGlobal(Function &F, const std::string &FrameString); - // Get the section name for frame descriptions. Currently ELF-only. - const char *getFrameSection() { return "__hwasan_frames"; } - const char *getFrameSectionBeg() { return "__start___hwasan_frames"; } - const char *getFrameSectionEnd() { return "__stop___hwasan_frames"; } - GlobalVariable *createFrameSectionBound(Module &M, Type *Ty, - const char *Name) { - auto GV = new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage, - nullptr, Name); - GV->setVisibility(GlobalValue::HiddenVisibility); - return GV; - } + FunctionCallee HWAsanMemmove, HWAsanMemcpy, HWAsanMemset; + FunctionCallee HWAsanHandleVfork; /// This struct defines the shadow mapping using the rule: /// shadow = (mem >> Scale) + Offset. @@ -253,48 +245,95 @@ private: Type *IntptrTy; Type *Int8PtrTy; Type *Int8Ty; + Type *Int32Ty; bool CompileKernel; bool Recover; Function *HwasanCtorFunction; - Function *HwasanMemoryAccessCallback[2][kNumberOfAccessSizes]; - Function *HwasanMemoryAccessCallbackSized[2]; + FunctionCallee HwasanMemoryAccessCallback[2][kNumberOfAccessSizes]; + FunctionCallee HwasanMemoryAccessCallbackSized[2]; - Function *HwasanTagMemoryFunc; - Function *HwasanGenerateTagFunc; - Function *HwasanThreadEnterFunc; + FunctionCallee HwasanTagMemoryFunc; + FunctionCallee HwasanGenerateTagFunc; + FunctionCallee HwasanThreadEnterFunc; Constant *ShadowGlobal; Value *LocalDynamicShadow = nullptr; + Value *StackBaseTag = nullptr; GlobalValue *ThreadPtrGlobal = nullptr; }; +class HWAddressSanitizerLegacyPass : public FunctionPass { +public: + // Pass identification, replacement for typeid. 
+ static char ID; + + explicit HWAddressSanitizerLegacyPass(bool CompileKernel = false, + bool Recover = false) + : FunctionPass(ID), CompileKernel(CompileKernel), Recover(Recover) {} + + StringRef getPassName() const override { return "HWAddressSanitizer"; } + + bool doInitialization(Module &M) override { + HWASan = llvm::make_unique<HWAddressSanitizer>(M, CompileKernel, Recover); + return true; + } + + bool runOnFunction(Function &F) override { + return HWASan->sanitizeFunction(F); + } + + bool doFinalization(Module &M) override { + HWASan.reset(); + return false; + } + +private: + std::unique_ptr<HWAddressSanitizer> HWASan; + bool CompileKernel; + bool Recover; +}; + } // end anonymous namespace -char HWAddressSanitizer::ID = 0; +char HWAddressSanitizerLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN( - HWAddressSanitizer, "hwasan", + HWAddressSanitizerLegacyPass, "hwasan", "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, false) INITIALIZE_PASS_END( - HWAddressSanitizer, "hwasan", + HWAddressSanitizerLegacyPass, "hwasan", "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, false) -FunctionPass *llvm::createHWAddressSanitizerPass(bool CompileKernel, - bool Recover) { +FunctionPass *llvm::createHWAddressSanitizerLegacyPassPass(bool CompileKernel, + bool Recover) { assert(!CompileKernel || Recover); - return new HWAddressSanitizer(CompileKernel, Recover); + return new HWAddressSanitizerLegacyPass(CompileKernel, Recover); +} + +HWAddressSanitizerPass::HWAddressSanitizerPass(bool CompileKernel, bool Recover) + : CompileKernel(CompileKernel), Recover(Recover) {} + +PreservedAnalyses HWAddressSanitizerPass::run(Module &M, + ModuleAnalysisManager &MAM) { + HWAddressSanitizer HWASan(M, CompileKernel, Recover); + bool Modified = false; + for (Function &F : M) + Modified |= HWASan.sanitizeFunction(F); + if (Modified) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); } /// Module-level initialization. /// /// inserts a call to __hwasan_init to the module's constructor list. -bool HWAddressSanitizer::doInitialization(Module &M) { +void HWAddressSanitizer::initializeModule(Module &M) { LLVM_DEBUG(dbgs() << "Init " << M.getName() << "\n"); auto &DL = M.getDataLayout(); @@ -308,47 +347,35 @@ bool HWAddressSanitizer::doInitialization(Module &M) { IntptrTy = IRB.getIntPtrTy(DL); Int8PtrTy = IRB.getInt8PtrTy(); Int8Ty = IRB.getInt8Ty(); + Int32Ty = IRB.getInt32Ty(); HwasanCtorFunction = nullptr; if (!CompileKernel) { std::tie(HwasanCtorFunction, std::ignore) = - createSanitizerCtorAndInitFunctions(M, kHwasanModuleCtorName, - kHwasanInitName, - /*InitArgTypes=*/{}, - /*InitArgs=*/{}); - Comdat *CtorComdat = M.getOrInsertComdat(kHwasanModuleCtorName); - HwasanCtorFunction->setComdat(CtorComdat); - appendToGlobalCtors(M, HwasanCtorFunction, 0, HwasanCtorFunction); - - // Create a zero-length global in __hwasan_frame so that the linker will - // always create start and stop symbols. - // - // N.B. If we ever start creating associated metadata in this pass this - // global will need to be associated with the ctor. 
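With the transform factored out of the legacy FunctionPass, the same HWAddressSanitizer object now backs both pass managers. A hedged sketch of driving the new-PM HWAddressSanitizerPass added above from a standalone pipeline; the analysis-manager boilerplate is the usual new-PM setup and is assumed, not taken from the patch (needs llvm/Passes/PassBuilder.h and the new HWAddressSanitizer.h header):

  PassBuilder PB;
  LoopAnalysisManager LAM;
  FunctionAnalysisManager FAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  ModulePassManager MPM;
  MPM.addPass(HWAddressSanitizerPass(/*CompileKernel=*/false, /*Recover=*/false));
  MPM.run(M, MAM);  // M is the llvm::Module& being instrumented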
- Type *Int8Arr0Ty = ArrayType::get(Int8Ty, 0); - auto GV = - new GlobalVariable(M, Int8Arr0Ty, /*isConstantGlobal*/ true, - GlobalVariable::PrivateLinkage, - Constant::getNullValue(Int8Arr0Ty), "__hwasan"); - GV->setSection(getFrameSection()); - GV->setComdat(CtorComdat); - appendToCompilerUsed(M, GV); - - IRBuilder<> IRBCtor(HwasanCtorFunction->getEntryBlock().getTerminator()); - IRBCtor.CreateCall( - declareSanitizerInitFunction(M, "__hwasan_init_frames", - {Int8PtrTy, Int8PtrTy}), - {createFrameSectionBound(M, Int8Ty, getFrameSectionBeg()), - createFrameSectionBound(M, Int8Ty, getFrameSectionEnd())}); + getOrCreateSanitizerCtorAndInitFunctions( + M, kHwasanModuleCtorName, kHwasanInitName, + /*InitArgTypes=*/{}, + /*InitArgs=*/{}, + // This callback is invoked when the functions are created the first + // time. Hook them into the global ctors list in that case: + [&](Function *Ctor, FunctionCallee) { + Comdat *CtorComdat = M.getOrInsertComdat(kHwasanModuleCtorName); + Ctor->setComdat(CtorComdat); + appendToGlobalCtors(M, Ctor, 0, Ctor); + }); } - if (!TargetTriple.isAndroid()) - appendToCompilerUsed( - M, ThreadPtrGlobal = new GlobalVariable( - M, IntptrTy, false, GlobalVariable::ExternalLinkage, nullptr, - "__hwasan_tls", nullptr, GlobalVariable::InitialExecTLSModel)); - - return true; + if (!TargetTriple.isAndroid()) { + Constant *C = M.getOrInsertGlobal("__hwasan_tls", IntptrTy, [&] { + auto *GV = new GlobalVariable(M, IntptrTy, /*isConstant=*/false, + GlobalValue::ExternalLinkage, nullptr, + "__hwasan_tls", nullptr, + GlobalVariable::InitialExecTLSModel); + appendToCompilerUsed(M, GV); + return GV; + }); + ThreadPtrGlobal = cast<GlobalVariable>(C); + } } void HWAddressSanitizer::initializeCallbacks(Module &M) { @@ -357,44 +384,55 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) { const std::string TypeStr = AccessIsWrite ? "store" : "load"; const std::string EndingStr = Recover ? 
"_noabort" : ""; - HwasanMemoryAccessCallbackSized[AccessIsWrite] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - ClMemoryAccessCallbackPrefix + TypeStr + "N" + EndingStr, - FunctionType::get(IRB.getVoidTy(), {IntptrTy, IntptrTy}, false))); + HwasanMemoryAccessCallbackSized[AccessIsWrite] = M.getOrInsertFunction( + ClMemoryAccessCallbackPrefix + TypeStr + "N" + EndingStr, + FunctionType::get(IRB.getVoidTy(), {IntptrTy, IntptrTy}, false)); for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; AccessSizeIndex++) { HwasanMemoryAccessCallback[AccessIsWrite][AccessSizeIndex] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( + M.getOrInsertFunction( ClMemoryAccessCallbackPrefix + TypeStr + itostr(1ULL << AccessSizeIndex) + EndingStr, - FunctionType::get(IRB.getVoidTy(), {IntptrTy}, false))); + FunctionType::get(IRB.getVoidTy(), {IntptrTy}, false)); } } - HwasanTagMemoryFunc = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__hwasan_tag_memory", IRB.getVoidTy(), Int8PtrTy, Int8Ty, IntptrTy)); - HwasanGenerateTagFunc = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("__hwasan_generate_tag", Int8Ty)); + HwasanTagMemoryFunc = M.getOrInsertFunction( + "__hwasan_tag_memory", IRB.getVoidTy(), Int8PtrTy, Int8Ty, IntptrTy); + HwasanGenerateTagFunc = + M.getOrInsertFunction("__hwasan_generate_tag", Int8Ty); - if (Mapping.InGlobal) - ShadowGlobal = M.getOrInsertGlobal("__hwasan_shadow", - ArrayType::get(IRB.getInt8Ty(), 0)); + ShadowGlobal = M.getOrInsertGlobal("__hwasan_shadow", + ArrayType::get(IRB.getInt8Ty(), 0)); const std::string MemIntrinCallbackPrefix = CompileKernel ? std::string("") : ClMemoryAccessCallbackPrefix; - HWAsanMemmove = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - MemIntrinCallbackPrefix + "memmove", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy)); - HWAsanMemcpy = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - MemIntrinCallbackPrefix + "memcpy", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy)); - HWAsanMemset = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - MemIntrinCallbackPrefix + "memset", IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy)); - - HwasanThreadEnterFunc = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("__hwasan_thread_enter", IRB.getVoidTy())); + HWAsanMemmove = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memmove", + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IntptrTy); + HWAsanMemcpy = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memcpy", + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IntptrTy); + HWAsanMemset = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memset", + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + IRB.getInt32Ty(), IntptrTy); + + HWAsanHandleVfork = + M.getOrInsertFunction("__hwasan_handle_vfork", IRB.getVoidTy(), IntptrTy); + + HwasanThreadEnterFunc = + M.getOrInsertFunction("__hwasan_thread_enter", IRB.getVoidTy()); +} + +Value *HWAddressSanitizer::getDynamicShadowIfunc(IRBuilder<> &IRB) { + // An empty inline asm with input reg == output reg. + // An opaque no-op cast, basically. 
+ InlineAsm *Asm = InlineAsm::get( + FunctionType::get(Int8PtrTy, {ShadowGlobal->getType()}, false), + StringRef(""), StringRef("=r,0"), + /*hasSideEffects=*/false); + return IRB.CreateCall(Asm, {ShadowGlobal}, ".hwasan.shadow"); } Value *HWAddressSanitizer::getDynamicShadowNonTls(IRBuilder<> &IRB) { @@ -403,18 +441,12 @@ Value *HWAddressSanitizer::getDynamicShadowNonTls(IRBuilder<> &IRB) { return nullptr; if (Mapping.InGlobal) { - // An empty inline asm with input reg == output reg. - // An opaque pointer-to-int cast, basically. - InlineAsm *Asm = InlineAsm::get( - FunctionType::get(IntptrTy, {ShadowGlobal->getType()}, false), - StringRef(""), StringRef("=r,0"), - /*hasSideEffects=*/false); - return IRB.CreateCall(Asm, {ShadowGlobal}, ".hwasan.shadow"); + return getDynamicShadowIfunc(IRB); } else { Value *GlobalDynamicAddress = IRB.GetInsertBlock()->getParent()->getParent()->getOrInsertGlobal( - kHwasanShadowMemoryDynamicAddress, IntptrTy); - return IRB.CreateLoad(GlobalDynamicAddress); + kHwasanShadowMemoryDynamicAddress, Int8PtrTy); + return IRB.CreateLoad(Int8PtrTy, GlobalDynamicAddress); } } @@ -506,29 +538,44 @@ void HWAddressSanitizer::untagPointerOperand(Instruction *I, Value *Addr) { I->setOperand(getPointerOperandIndex(I), UntaggedPtr); } -Value *HWAddressSanitizer::memToShadow(Value *Mem, Type *Ty, IRBuilder<> &IRB) { +Value *HWAddressSanitizer::shadowBase() { + if (LocalDynamicShadow) + return LocalDynamicShadow; + return ConstantExpr::getIntToPtr(ConstantInt::get(IntptrTy, Mapping.Offset), + Int8PtrTy); +} + +Value *HWAddressSanitizer::memToShadow(Value *Mem, IRBuilder<> &IRB) { // Mem >> Scale Value *Shadow = IRB.CreateLShr(Mem, Mapping.Scale); if (Mapping.Offset == 0) - return Shadow; + return IRB.CreateIntToPtr(Shadow, Int8PtrTy); // (Mem >> Scale) + Offset - Value *ShadowBase; - if (LocalDynamicShadow) - ShadowBase = LocalDynamicShadow; - else - ShadowBase = ConstantInt::get(Ty, Mapping.Offset); - return IRB.CreateAdd(Shadow, ShadowBase); + return IRB.CreateGEP(Int8Ty, shadowBase(), Shadow); } -void HWAddressSanitizer::instrumentMemAccessInline(Value *PtrLong, bool IsWrite, +void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, unsigned AccessSizeIndex, Instruction *InsertBefore) { + const int64_t AccessInfo = Recover * 0x20 + IsWrite * 0x10 + AccessSizeIndex; IRBuilder<> IRB(InsertBefore); + + if (!ClInlineAllChecks && TargetTriple.isAArch64() && + TargetTriple.isOSBinFormatELF() && !Recover) { + Module *M = IRB.GetInsertBlock()->getParent()->getParent(); + Ptr = IRB.CreateBitCast(Ptr, Int8PtrTy); + IRB.CreateCall( + Intrinsic::getDeclaration(M, Intrinsic::hwasan_check_memaccess), + {shadowBase(), Ptr, ConstantInt::get(Int32Ty, AccessInfo)}); + return; + } + + Value *PtrLong = IRB.CreatePointerCast(Ptr, IntptrTy); Value *PtrTag = IRB.CreateTrunc(IRB.CreateLShr(PtrLong, kPointerTagShift), IRB.getInt8Ty()); Value *AddrLong = untagPointer(IRB, PtrLong); - Value *ShadowLong = memToShadow(AddrLong, PtrLong->getType(), IRB); - Value *MemTag = IRB.CreateLoad(IRB.CreateIntToPtr(ShadowLong, Int8PtrTy)); + Value *Shadow = memToShadow(AddrLong, IRB); + Value *MemTag = IRB.CreateLoad(Int8Ty, Shadow); Value *TagMismatch = IRB.CreateICmpNE(PtrTag, MemTag); int matchAllTag = ClMatchAllTag.getNumOccurrences() > 0 ? 
@@ -540,11 +587,35 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *PtrLong, bool IsWrite, } Instruction *CheckTerm = - SplitBlockAndInsertIfThen(TagMismatch, InsertBefore, !Recover, + SplitBlockAndInsertIfThen(TagMismatch, InsertBefore, false, MDBuilder(*C).createBranchWeights(1, 100000)); IRB.SetInsertPoint(CheckTerm); - const int64_t AccessInfo = Recover * 0x20 + IsWrite * 0x10 + AccessSizeIndex; + Value *OutOfShortGranuleTagRange = + IRB.CreateICmpUGT(MemTag, ConstantInt::get(Int8Ty, 15)); + Instruction *CheckFailTerm = + SplitBlockAndInsertIfThen(OutOfShortGranuleTagRange, CheckTerm, !Recover, + MDBuilder(*C).createBranchWeights(1, 100000)); + + IRB.SetInsertPoint(CheckTerm); + Value *PtrLowBits = IRB.CreateTrunc(IRB.CreateAnd(PtrLong, 15), Int8Ty); + PtrLowBits = IRB.CreateAdd( + PtrLowBits, ConstantInt::get(Int8Ty, (1 << AccessSizeIndex) - 1)); + Value *PtrLowBitsOOB = IRB.CreateICmpUGE(PtrLowBits, MemTag); + SplitBlockAndInsertIfThen(PtrLowBitsOOB, CheckTerm, false, + MDBuilder(*C).createBranchWeights(1, 100000), + nullptr, nullptr, CheckFailTerm->getParent()); + + IRB.SetInsertPoint(CheckTerm); + Value *InlineTagAddr = IRB.CreateOr(AddrLong, 15); + InlineTagAddr = IRB.CreateIntToPtr(InlineTagAddr, Int8PtrTy); + Value *InlineTag = IRB.CreateLoad(Int8Ty, InlineTagAddr); + Value *InlineTagMismatch = IRB.CreateICmpNE(PtrTag, InlineTag); + SplitBlockAndInsertIfThen(InlineTagMismatch, CheckTerm, false, + MDBuilder(*C).createBranchWeights(1, 100000), + nullptr, nullptr, CheckFailTerm->getParent()); + + IRB.SetInsertPoint(CheckFailTerm); InlineAsm *Asm; switch (TargetTriple.getArch()) { case Triple::x86_64: @@ -568,6 +639,8 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *PtrLong, bool IsWrite, report_fatal_error("unsupported architecture"); } IRB.CreateCall(Asm, PtrLong); + if (Recover) + cast<BranchInst>(CheckFailTerm)->setSuccessor(0, CheckTerm->getParent()); } void HWAddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { @@ -610,7 +683,6 @@ bool HWAddressSanitizer::instrumentMemAccess(Instruction *I) { return false; //FIXME IRBuilder<> IRB(I); - Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); if (isPowerOf2_64(TypeSize) && (TypeSize / 8 <= (1UL << (kNumberOfAccessSizes - 1))) && (Alignment >= (1UL << Mapping.Scale) || Alignment == 0 || @@ -618,13 +690,14 @@ bool HWAddressSanitizer::instrumentMemAccess(Instruction *I) { size_t AccessSizeIndex = TypeSizeToSizeIndex(TypeSize); if (ClInstrumentWithCalls) { IRB.CreateCall(HwasanMemoryAccessCallback[IsWrite][AccessSizeIndex], - AddrLong); + IRB.CreatePointerCast(Addr, IntptrTy)); } else { - instrumentMemAccessInline(AddrLong, IsWrite, AccessSizeIndex, I); + instrumentMemAccessInline(Addr, IsWrite, AccessSizeIndex, I); } } else { IRB.CreateCall(HwasanMemoryAccessCallbackSized[IsWrite], - {AddrLong, ConstantInt::get(IntptrTy, TypeSize / 8)}); + {IRB.CreatePointerCast(Addr, IntptrTy), + ConstantInt::get(IntptrTy, TypeSize / 8)}); } untagPointerOperand(I, Addr); @@ -644,27 +717,33 @@ static uint64_t getAllocaSizeInBytes(const AllocaInst &AI) { } bool HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, - Value *Tag) { - size_t Size = (getAllocaSizeInBytes(*AI) + Mapping.getAllocaAlignment() - 1) & - ~(Mapping.getAllocaAlignment() - 1); + Value *Tag, size_t Size) { + size_t AlignedSize = alignTo(Size, Mapping.getAllocaAlignment()); Value *JustTag = IRB.CreateTrunc(Tag, IRB.getInt8Ty()); if (ClInstrumentWithCalls) { IRB.CreateCall(HwasanTagMemoryFunc, {IRB.CreatePointerCast(AI, Int8PtrTy), 
JustTag, - ConstantInt::get(IntptrTy, Size)}); + ConstantInt::get(IntptrTy, AlignedSize)}); } else { size_t ShadowSize = Size >> Mapping.Scale; - Value *ShadowPtr = IRB.CreateIntToPtr( - memToShadow(IRB.CreatePointerCast(AI, IntptrTy), AI->getType(), IRB), - Int8PtrTy); + Value *ShadowPtr = memToShadow(IRB.CreatePointerCast(AI, IntptrTy), IRB); // If this memset is not inlined, it will be intercepted in the hwasan // runtime library. That's OK, because the interceptor skips the checks if // the address is in the shadow region. // FIXME: the interceptor is not as fast as real memset. Consider lowering // llvm.memset right here into either a sequence of stores, or a call to // hwasan_tag_memory. - IRB.CreateMemSet(ShadowPtr, JustTag, ShadowSize, /*Align=*/1); + if (ShadowSize) + IRB.CreateMemSet(ShadowPtr, JustTag, ShadowSize, /*Align=*/1); + if (Size != AlignedSize) { + IRB.CreateStore( + ConstantInt::get(Int8Ty, Size % Mapping.getAllocaAlignment()), + IRB.CreateConstGEP1_32(Int8Ty, ShadowPtr, ShadowSize)); + IRB.CreateStore(JustTag, IRB.CreateConstGEP1_32( + Int8Ty, IRB.CreateBitCast(AI, Int8PtrTy), + AlignedSize - 1)); + } } return true; } @@ -674,10 +753,16 @@ static unsigned RetagMask(unsigned AllocaNo) { // x = x ^ (mask << 56) can be encoded as a single armv8 instruction for these // masks. // The list does not include the value 255, which is used for UAR. - static unsigned FastMasks[] = { - 0, 1, 2, 3, 4, 6, 7, 8, 12, 14, 15, 16, 24, - 28, 30, 31, 32, 48, 56, 60, 62, 63, 64, 96, 112, 120, - 124, 126, 127, 128, 192, 224, 240, 248, 252, 254}; + // + // Because we are more likely to use earlier elements of this list than later + // ones, it is sorted in increasing order of probability of collision with a + // mask allocated (temporally) nearby. The program that generated this list + // can be found at: + // https://github.com/google/sanitizers/blob/master/hwaddress-sanitizer/sort_masks.py + static unsigned FastMasks[] = {0, 128, 64, 192, 32, 96, 224, 112, 240, + 48, 16, 120, 248, 56, 24, 8, 124, 252, + 60, 28, 12, 4, 126, 254, 62, 30, 14, + 6, 2, 127, 63, 31, 15, 7, 3, 1}; return FastMasks[AllocaNo % (sizeof(FastMasks) / sizeof(FastMasks[0]))]; } @@ -688,6 +773,8 @@ Value *HWAddressSanitizer::getNextTagWithCall(IRBuilder<> &IRB) { Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) { if (ClGenerateTagsWithCalls) return getNextTagWithCall(IRB); + if (StackBaseTag) + return StackBaseTag; // FIXME: use addressofreturnaddress (but implement it in aarch64 backend // first). Module *M = IRB.GetInsertBlock()->getParent()->getParent(); @@ -763,7 +850,8 @@ Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty) { Function *ThreadPointerFunc = Intrinsic::getDeclaration(M, Intrinsic::thread_pointer); Value *SlotPtr = IRB.CreatePointerCast( - IRB.CreateConstGEP1_32(IRB.CreateCall(ThreadPointerFunc), 0x30), + IRB.CreateConstGEP1_32(IRB.getInt8Ty(), + IRB.CreateCall(ThreadPointerFunc), 0x30), Ty->getPointerTo(0)); return SlotPtr; } @@ -774,45 +862,21 @@ Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty) { return nullptr; } -// Creates a string with a description of the stack frame (set of Allocas). -// The string is intended to be human readable. -// The current form is: Size1 Name1; Size2 Name2; ... 
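// Worked example of the short-granule encoding in tagAlloca above, assuming
// the default 16-byte granule (Mapping.getAllocaAlignment() == 16, Scale == 4);
// the numbers are illustrative, not from the patch:
//   Size = 10  =>  AlignedSize = 16, ShadowSize = 10 >> 4 = 0
//   shadow byte for the granule  <- 10   (Size % 16: the live size, not the tag)
//   alloca byte at offset 15     <- tag  (real tag kept in the granule's last byte)
// The expanded check in instrumentMemAccessInline treats shadow values 1..15 as
// short granules: the access passes if its last byte stays below the live size,
// and the pointer tag is then compared against that trailing in-memory tag
// instead of the shadow byte.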
-std::string -HWAddressSanitizer::createFrameString(ArrayRef<AllocaInst *> Allocas) { - std::ostringstream Descr; - for (auto AI : Allocas) - Descr << getAllocaSizeInBytes(*AI) << " " << AI->getName().str() << "; "; - return Descr.str(); -} - -// Creates a global in the frame section which consists of two pointers: -// the function PC and the frame string constant. -void HWAddressSanitizer::createFrameGlobal(Function &F, - const std::string &FrameString) { - Module &M = *F.getParent(); - auto DescrGV = createPrivateGlobalForString(M, FrameString, true); - auto PtrPairTy = StructType::get(F.getType(), DescrGV->getType()); - auto GV = new GlobalVariable( - M, PtrPairTy, /*isConstantGlobal*/ true, GlobalVariable::PrivateLinkage, - ConstantStruct::get(PtrPairTy, (Constant *)&F, (Constant *)DescrGV), - "__hwasan"); - GV->setSection(getFrameSection()); - appendToCompilerUsed(M, GV); - // Put GV into the F's Comadat so that if F is deleted GV can be deleted too. - if (auto Comdat = - GetOrCreateFunctionComdat(F, TargetTriple, CurModuleUniqueId)) - GV->setComdat(Comdat); -} +void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) { + if (!Mapping.InTls) { + LocalDynamicShadow = getDynamicShadowNonTls(IRB); + return; + } -Value *HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, - bool WithFrameRecord) { - if (!Mapping.InTls) - return getDynamicShadowNonTls(IRB); + if (!WithFrameRecord && TargetTriple.isAndroid()) { + LocalDynamicShadow = getDynamicShadowIfunc(IRB); + return; + } Value *SlotPtr = getHwasanThreadSlotPtr(IRB, IntptrTy); assert(SlotPtr); - Instruction *ThreadLong = IRB.CreateLoad(SlotPtr); + Instruction *ThreadLong = IRB.CreateLoad(IntptrTy, SlotPtr); Function *F = IRB.GetInsertBlock()->getParent(); if (F->getFnAttribute("hwasan-abi").getValueAsString() == "interceptor") { @@ -826,7 +890,7 @@ Value *HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, // FIXME: This should call a new runtime function with a custom calling // convention to avoid needing to spill all arguments here. IRB.CreateCall(HwasanThreadEnterFunc); - LoadInst *ReloadThreadLong = IRB.CreateLoad(SlotPtr); + LoadInst *ReloadThreadLong = IRB.CreateLoad(IntptrTy, SlotPtr); IRB.SetInsertPoint(&*Br->getSuccessor(0)->begin()); PHINode *ThreadLongPhi = IRB.CreatePHI(IntptrTy, 2); @@ -840,15 +904,21 @@ Value *HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, TargetTriple.isAArch64() ? ThreadLong : untagPointer(IRB, ThreadLong); if (WithFrameRecord) { + StackBaseTag = IRB.CreateAShr(ThreadLong, 3); + // Prepare ring buffer data. - auto PC = IRB.CreatePtrToInt(F, IntptrTy); + Value *PC; + if (TargetTriple.getArch() == Triple::aarch64) + PC = readRegister(IRB, "pc"); + else + PC = IRB.CreatePtrToInt(F, IntptrTy); auto GetStackPointerFn = Intrinsic::getDeclaration(F->getParent(), Intrinsic::frameaddress); Value *SP = IRB.CreatePtrToInt( IRB.CreateCall(GetStackPointerFn, {Constant::getNullValue(IRB.getInt32Ty())}), IntptrTy); - // Mix SP and PC. TODO: also add the tag to the mix. + // Mix SP and PC. // Assumptions: // PC is 0x0000PPPPPPPPPPPP (48 bits are meaningful, others are zero) // SP is 0xsssssssssssSSSS0 (4 lower bits are zero) @@ -879,16 +949,38 @@ Value *HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, // Get shadow base address by aligning RecordPtr up. // Note: this is not correct if the pointer is already aligned. // Runtime library will make sure this never happens. 
- Value *ShadowBase = IRB.CreateAdd( + LocalDynamicShadow = IRB.CreateAdd( IRB.CreateOr( ThreadLongMaybeUntagged, ConstantInt::get(IntptrTy, (1ULL << kShadowBaseAlignment) - 1)), ConstantInt::get(IntptrTy, 1), "hwasan.shadow"); - return ShadowBase; + LocalDynamicShadow = IRB.CreateIntToPtr(LocalDynamicShadow, Int8PtrTy); +} + +Value *HWAddressSanitizer::readRegister(IRBuilder<> &IRB, StringRef Name) { + Module *M = IRB.GetInsertBlock()->getParent()->getParent(); + Function *ReadRegister = + Intrinsic::getDeclaration(M, Intrinsic::read_register, IntptrTy); + MDNode *MD = MDNode::get(*C, {MDString::get(*C, Name)}); + Value *Args[] = {MetadataAsValue::get(*C, MD)}; + return IRB.CreateCall(ReadRegister, Args); +} + +bool HWAddressSanitizer::instrumentLandingPads( + SmallVectorImpl<Instruction *> &LandingPadVec) { + for (auto *LP : LandingPadVec) { + IRBuilder<> IRB(LP->getNextNode()); + IRB.CreateCall( + HWAsanHandleVfork, + {readRegister(IRB, (TargetTriple.getArch() == Triple::x86_64) ? "rsp" + : "sp")}); + } + return true; } bool HWAddressSanitizer::instrumentStack( SmallVectorImpl<AllocaInst *> &Allocas, + DenseMap<AllocaInst *, std::vector<DbgDeclareInst *>> &AllocaDeclareMap, SmallVectorImpl<Instruction *> &RetVec, Value *StackTag) { // Ideally, we want to calculate tagged stack base pointer, and rewrite all // alloca addresses using that. Unfortunately, offsets are not known yet @@ -913,14 +1005,22 @@ bool HWAddressSanitizer::instrumentStack( U.set(Replacement); } - tagAlloca(IRB, AI, Tag); + for (auto *DDI : AllocaDeclareMap.lookup(AI)) { + DIExpression *OldExpr = DDI->getExpression(); + DIExpression *NewExpr = DIExpression::append( + OldExpr, {dwarf::DW_OP_LLVM_tag_offset, RetagMask(N)}); + DDI->setArgOperand(2, MetadataAsValue::get(*C, NewExpr)); + } + + size_t Size = getAllocaSizeInBytes(*AI); + tagAlloca(IRB, AI, Tag, Size); for (auto RI : RetVec) { IRB.SetInsertPoint(RI); // Re-tag alloca memory with the special UAR tag. Value *Tag = getUARTag(IRB, StackTag); - tagAlloca(IRB, AI, Tag); + tagAlloca(IRB, AI, Tag, alignTo(Size, Mapping.getAllocaAlignment())); } } @@ -943,7 +1043,7 @@ bool HWAddressSanitizer::isInterestingAlloca(const AllocaInst &AI) { !AI.isSwiftError()); } -bool HWAddressSanitizer::runOnFunction(Function &F) { +bool HWAddressSanitizer::sanitizeFunction(Function &F) { if (&F == HwasanCtorFunction) return false; @@ -955,15 +1055,12 @@ bool HWAddressSanitizer::runOnFunction(Function &F) { SmallVector<Instruction*, 16> ToInstrument; SmallVector<AllocaInst*, 8> AllocasToInstrument; SmallVector<Instruction*, 8> RetVec; + SmallVector<Instruction*, 8> LandingPadVec; + DenseMap<AllocaInst *, std::vector<DbgDeclareInst *>> AllocaDeclareMap; for (auto &BB : F) { for (auto &Inst : BB) { if (ClInstrumentStack) if (AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) { - // Realign all allocas. We don't want small uninteresting allocas to - // hide in instrumented alloca's padding. - if (AI->getAlignment() < Mapping.getAllocaAlignment()) - AI->setAlignment(Mapping.getAllocaAlignment()); - // Instrument some of them. 
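// The readRegister/instrumentLandingPads helpers above boil down to IR like
// the following at each landing pad (illustrative, AArch64 register name,
// i64 intptr assumed):
//   %sp = call i64 @llvm.read_register.i64(metadata !0)
//   call void @__hwasan_handle_vfork(i64 %sp)
//   ...
//   !0 = !{!"sp"}
// Passing the current stack pointer lets the runtime drop stale stack tags
// left behind by frames that were unwound past, which would otherwise never
// be re-tagged.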
if (isInterestingAlloca(*AI)) AllocasToInstrument.push_back(AI); continue; @@ -973,6 +1070,13 @@ bool HWAddressSanitizer::runOnFunction(Function &F) { isa<CleanupReturnInst>(Inst)) RetVec.push_back(&Inst); + if (auto *DDI = dyn_cast<DbgDeclareInst>(&Inst)) + if (auto *Alloca = dyn_cast_or_null<AllocaInst>(DDI->getAddress())) + AllocaDeclareMap[Alloca].push_back(DDI); + + if (ClInstrumentLandingPads && isa<LandingPadInst>(Inst)) + LandingPadVec.push_back(&Inst); + Value *MaybeMask = nullptr; bool IsWrite; unsigned Alignment; @@ -984,33 +1088,93 @@ bool HWAddressSanitizer::runOnFunction(Function &F) { } } - if (AllocasToInstrument.empty() && ToInstrument.empty()) - return false; + initializeCallbacks(*F.getParent()); - if (ClCreateFrameDescriptions && !AllocasToInstrument.empty()) - createFrameGlobal(F, createFrameString(AllocasToInstrument)); + if (!LandingPadVec.empty()) + instrumentLandingPads(LandingPadVec); - initializeCallbacks(*F.getParent()); + if (AllocasToInstrument.empty() && ToInstrument.empty()) + return false; assert(!LocalDynamicShadow); Instruction *InsertPt = &*F.getEntryBlock().begin(); IRBuilder<> EntryIRB(InsertPt); - LocalDynamicShadow = emitPrologue(EntryIRB, - /*WithFrameRecord*/ ClRecordStackHistory && - !AllocasToInstrument.empty()); + emitPrologue(EntryIRB, + /*WithFrameRecord*/ ClRecordStackHistory && + !AllocasToInstrument.empty()); bool Changed = false; if (!AllocasToInstrument.empty()) { Value *StackTag = ClGenerateTagsWithCalls ? nullptr : getStackBaseTag(EntryIRB); - Changed |= instrumentStack(AllocasToInstrument, RetVec, StackTag); + Changed |= instrumentStack(AllocasToInstrument, AllocaDeclareMap, RetVec, + StackTag); + } + + // Pad and align each of the allocas that we instrumented to stop small + // uninteresting allocas from hiding in instrumented alloca's padding and so + // that we have enough space to store real tags for short granules. 
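// Concretely (illustrative, not from the patch): with 16-byte granules an
// instrumented
//   %x = alloca [10 x i8], align 4
// becomes
//   %x = alloca { [10 x i8], [6 x i8] }, align 16
// and the old users are rewired through a bitcast of the padded alloca back to
// [10 x i8]*, so the 6 padding bytes plus the realignment guarantee that every
// tagged object covers whole granules of its own.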
+ DenseMap<AllocaInst *, AllocaInst *> AllocaToPaddedAllocaMap; + for (AllocaInst *AI : AllocasToInstrument) { + uint64_t Size = getAllocaSizeInBytes(*AI); + uint64_t AlignedSize = alignTo(Size, Mapping.getAllocaAlignment()); + AI->setAlignment(std::max(AI->getAlignment(), 16u)); + if (Size != AlignedSize) { + Type *AllocatedType = AI->getAllocatedType(); + if (AI->isArrayAllocation()) { + uint64_t ArraySize = + cast<ConstantInt>(AI->getArraySize())->getZExtValue(); + AllocatedType = ArrayType::get(AllocatedType, ArraySize); + } + Type *TypeWithPadding = StructType::get( + AllocatedType, ArrayType::get(Int8Ty, AlignedSize - Size)); + auto *NewAI = new AllocaInst( + TypeWithPadding, AI->getType()->getAddressSpace(), nullptr, "", AI); + NewAI->takeName(AI); + NewAI->setAlignment(AI->getAlignment()); + NewAI->setUsedWithInAlloca(AI->isUsedWithInAlloca()); + NewAI->setSwiftError(AI->isSwiftError()); + NewAI->copyMetadata(*AI); + auto *Bitcast = new BitCastInst(NewAI, AI->getType(), "", AI); + AI->replaceAllUsesWith(Bitcast); + AllocaToPaddedAllocaMap[AI] = NewAI; + } + } + + if (!AllocaToPaddedAllocaMap.empty()) { + for (auto &BB : F) + for (auto &Inst : BB) + if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&Inst)) + if (auto *AI = + dyn_cast_or_null<AllocaInst>(DVI->getVariableLocation())) + if (auto *NewAI = AllocaToPaddedAllocaMap.lookup(AI)) + DVI->setArgOperand( + 0, MetadataAsValue::get(*C, LocalAsMetadata::get(NewAI))); + for (auto &P : AllocaToPaddedAllocaMap) + P.first->eraseFromParent(); + } + + // If we split the entry block, move any allocas that were originally in the + // entry block back into the entry block so that they aren't treated as + // dynamic allocas. + if (EntryIRB.GetInsertBlock() != &F.getEntryBlock()) { + InsertPt = &*F.getEntryBlock().begin(); + for (auto II = EntryIRB.GetInsertBlock()->begin(), + IE = EntryIRB.GetInsertBlock()->end(); + II != IE;) { + Instruction *I = &*II++; + if (auto *AI = dyn_cast<AllocaInst>(I)) + if (isa<ConstantInt>(AI->getArraySize())) + I->moveBefore(InsertPt); + } } for (auto Inst : ToInstrument) Changed |= instrumentMemAccess(Inst); LocalDynamicShadow = nullptr; + StackBaseTag = nullptr; return Changed; } diff --git a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 58436c8560ad..c7371f567ff3 100644 --- a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -1,9 +1,8 @@ //===- IndirectCallPromotion.cpp - Optimizations based on value profiling -===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -239,7 +238,7 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( LLVM_DEBUG(dbgs() << " Candidate " << I << " Count=" << Count << " Target_func: " << Target << "\n"); - if (ICPInvokeOnly && dyn_cast<CallInst>(Inst)) { + if (ICPInvokeOnly && isa<CallInst>(Inst)) { LLVM_DEBUG(dbgs() << " Not promote: User options.\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "UserOptions", Inst) @@ -247,7 +246,7 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( }); break; } - if (ICPCallOnly && dyn_cast<InvokeInst>(Inst)) { + if (ICPCallOnly && isa<InvokeInst>(Inst)) { LLVM_DEBUG(dbgs() << " Not promote: User option.\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "UserOptions", Inst) @@ -311,10 +310,10 @@ Instruction *llvm::pgo::promoteIndirectCall(Instruction *Inst, promoteCallWithIfThenElse(CallSite(Inst), DirectCallee, BranchWeights); if (AttachProfToDirectCall) { - SmallVector<uint32_t, 1> Weights; - Weights.push_back(Count); MDBuilder MDB(NewInst->getContext()); - NewInst->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); + NewInst->setMetadata( + LLVMContext::MD_prof, + MDB.createBranchWeights({static_cast<uint32_t>(Count)})); } using namespace ore; @@ -394,9 +393,7 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, } bool Changed = false; for (auto &F : M) { - if (F.isDeclaration()) - continue; - if (F.hasFnAttribute(Attribute::OptimizeNone)) + if (F.isDeclaration() || F.hasOptNone()) continue; std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; diff --git a/lib/Transforms/Instrumentation/InstrOrderFile.cpp b/lib/Transforms/Instrumentation/InstrOrderFile.cpp new file mode 100644 index 000000000000..a2c1ddfd279e --- /dev/null +++ b/lib/Transforms/Instrumentation/InstrOrderFile.cpp @@ -0,0 +1,211 @@ +//===- InstrOrderFile.cpp ---- Late IR instrumentation for order file ----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/PassRegistry.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Instrumentation/InstrOrderFile.h" +#include <fstream> +#include <map> +#include <mutex> +#include <set> +#include <sstream> + +using namespace llvm; +#define DEBUG_TYPE "instrorderfile" + +static cl::opt<std::string> ClOrderFileWriteMapping( + "orderfile-write-mapping", cl::init(""), + cl::desc( + "Dump functions and their MD5 hash to deobfuscate profile data"), + cl::Hidden); + +namespace { + +// We need a global bitmap to tell if a function is executed. 
We also +// need a global variable to save the order of functions. We can use a +// fixed-size buffer that saves the MD5 hash of the function. We need +// a global variable to save the index into the buffer. + +std::mutex MappingMutex; + +struct InstrOrderFile { +private: + GlobalVariable *OrderFileBuffer; + GlobalVariable *BufferIdx; + GlobalVariable *BitMap; + ArrayType *BufferTy; + ArrayType *MapTy; + +public: + InstrOrderFile() {} + + void createOrderFileData(Module &M) { + LLVMContext &Ctx = M.getContext(); + int NumFunctions = 0; + for (Function &F : M) { + if (!F.isDeclaration()) + NumFunctions++; + } + + BufferTy = + ArrayType::get(Type::getInt64Ty(Ctx), INSTR_ORDER_FILE_BUFFER_SIZE); + Type *IdxTy = Type::getInt32Ty(Ctx); + MapTy = ArrayType::get(Type::getInt8Ty(Ctx), NumFunctions); + + // Create the global variables. + std::string SymbolName = INSTR_PROF_ORDERFILE_BUFFER_NAME_STR; + OrderFileBuffer = new GlobalVariable(M, BufferTy, false, GlobalValue::LinkOnceODRLinkage, + Constant::getNullValue(BufferTy), SymbolName); + Triple TT = Triple(M.getTargetTriple()); + OrderFileBuffer->setSection( + getInstrProfSectionName(IPSK_orderfile, TT.getObjectFormat())); + + std::string IndexName = INSTR_PROF_ORDERFILE_BUFFER_IDX_NAME_STR; + BufferIdx = new GlobalVariable(M, IdxTy, false, GlobalValue::LinkOnceODRLinkage, + Constant::getNullValue(IdxTy), IndexName); + + std::string BitMapName = "bitmap_0"; + BitMap = new GlobalVariable(M, MapTy, false, GlobalValue::PrivateLinkage, + Constant::getNullValue(MapTy), BitMapName); + } + + // Generate the code sequence in the entry block of each function to + // update the buffer. + void generateCodeSequence(Module &M, Function &F, int FuncId) { + if (!ClOrderFileWriteMapping.empty()) { + std::lock_guard<std::mutex> LogLock(MappingMutex); + std::error_code EC; + llvm::raw_fd_ostream OS(ClOrderFileWriteMapping, EC, llvm::sys::fs::F_Append); + if (EC) { + report_fatal_error(Twine("Failed to open ") + ClOrderFileWriteMapping + + " to save mapping file for order file instrumentation\n"); + } else { + std::stringstream stream; + stream << std::hex << MD5Hash(F.getName()); + std::string singleLine = "MD5 " + stream.str() + " " + + std::string(F.getName()) + '\n'; + OS << singleLine; + } + } + + BasicBlock *OrigEntry = &F.getEntryBlock(); + + LLVMContext &Ctx = M.getContext(); + IntegerType *Int32Ty = Type::getInt32Ty(Ctx); + IntegerType *Int8Ty = Type::getInt8Ty(Ctx); + + // Create a new entry block for instrumentation. We will check the bitmap + // in this basic block. + BasicBlock *NewEntry = + BasicBlock::Create(M.getContext(), "order_file_entry", &F, OrigEntry); + IRBuilder<> entryB(NewEntry); + // Create a basic block for updating the circular buffer. + BasicBlock *UpdateOrderFileBB = + BasicBlock::Create(M.getContext(), "order_file_set", &F, OrigEntry); + IRBuilder<> updateB(UpdateOrderFileBB); + + // Check the bitmap, if it is already 1, do nothing. + // Otherwise, set the bit, grab the index, update the buffer. + Value *IdxFlags[] = {ConstantInt::get(Int32Ty, 0), + ConstantInt::get(Int32Ty, FuncId)}; + Value *MapAddr = entryB.CreateGEP(MapTy, BitMap, IdxFlags, ""); + LoadInst *loadBitMap = entryB.CreateLoad(Int8Ty, MapAddr, ""); + entryB.CreateStore(ConstantInt::get(Int8Ty, 1), MapAddr); + Value *IsNotExecuted = + entryB.CreateICmpEQ(loadBitMap, ConstantInt::get(Int8Ty, 0)); + entryB.CreateCondBr(IsNotExecuted, UpdateOrderFileBB, OrigEntry); + + // Fill up UpdateOrderFileBB: grab the index, update the buffer! 
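// C-level equivalent of the whole generated entry-block sequence (illustrative
// pseudo-C; the real globals are the buffer, index and bitmap created in
// createOrderFileData above):
//   uint8_t seen = bitmap[FuncId];
//   bitmap[FuncId] = 1;
//   if (seen == 0) {
//     uint32_t idx = atomic_fetch_add(&buffer_idx, 1);   // seq_cst
//     buffer[idx & INSTR_ORDER_FILE_BUFFER_MASK] = MD5Hash(FuncName);
//   }
// The AND with the power-of-two mask wraps the monotonically growing index, so
// the fixed-size buffer behaves as a ring if more functions execute than it
// has slots.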
+ Value *IdxVal = updateB.CreateAtomicRMW( + AtomicRMWInst::Add, BufferIdx, ConstantInt::get(Int32Ty, 1), + AtomicOrdering::SequentiallyConsistent); + // We need to wrap around the index to fit it inside the buffer. + Value *WrappedIdx = updateB.CreateAnd( + IdxVal, ConstantInt::get(Int32Ty, INSTR_ORDER_FILE_BUFFER_MASK)); + Value *BufferGEPIdx[] = {ConstantInt::get(Int32Ty, 0), WrappedIdx}; + Value *BufferAddr = + updateB.CreateGEP(BufferTy, OrderFileBuffer, BufferGEPIdx, ""); + updateB.CreateStore(ConstantInt::get(Type::getInt64Ty(Ctx), MD5Hash(F.getName())), + BufferAddr); + updateB.CreateBr(OrigEntry); + } + + bool run(Module &M) { + createOrderFileData(M); + + int FuncId = 0; + for (Function &F : M) { + if (F.isDeclaration()) + continue; + generateCodeSequence(M, F, FuncId); + ++FuncId; + } + + return true; + } + +}; // End of InstrOrderFile struct + +class InstrOrderFileLegacyPass : public ModulePass { +public: + static char ID; + + InstrOrderFileLegacyPass() : ModulePass(ID) { + initializeInstrOrderFileLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override; +}; + +} // End anonymous namespace + +bool InstrOrderFileLegacyPass::runOnModule(Module &M) { + if (skipModule(M)) + return false; + + return InstrOrderFile().run(M); +} + +PreservedAnalyses +InstrOrderFilePass::run(Module &M, ModuleAnalysisManager &AM) { + if (InstrOrderFile().run(M)) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} + +INITIALIZE_PASS_BEGIN(InstrOrderFileLegacyPass, "instrorderfile", + "Instrumentation for Order File", false, false) +INITIALIZE_PASS_END(InstrOrderFileLegacyPass, "instrorderfile", + "Instrumentation for Order File", false, false) + +char InstrOrderFileLegacyPass::ID = 0; + +ModulePass *llvm::createInstrOrderFilePass() { + return new InstrOrderFileLegacyPass(); +} diff --git a/lib/Transforms/Instrumentation/InstrProfiling.cpp b/lib/Transforms/Instrumentation/InstrProfiling.cpp index 15b94388cbe5..63c2b8078967 100644 --- a/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -1,9 +1,8 @@ //===-- InstrProfiling.cpp - Frontend instrumentation based profiling -----===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -19,6 +18,8 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Attributes.h" @@ -148,8 +149,8 @@ public: static char ID; InstrProfilingLegacyPass() : ModulePass(ID) {} - InstrProfilingLegacyPass(const InstrProfOptions &Options) - : ModulePass(ID), InstrProf(Options) {} + InstrProfilingLegacyPass(const InstrProfOptions &Options, bool IsCS = false) + : ModulePass(ID), InstrProf(Options, IsCS) {} StringRef getPassName() const override { return "Frontend instrumentation-based coverage lowering"; @@ -187,7 +188,7 @@ public: SSA.AddAvailableValue(PH, Init); } - void doExtraRewritesBeforeFinalDeletion() const override { + void doExtraRewritesBeforeFinalDeletion() override { for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { BasicBlock *ExitBlock = ExitBlocks[i]; Instruction *InsertPos = InsertPts[i]; @@ -196,6 +197,7 @@ public: // block. Value *LiveInValue = SSA.GetValueInMiddleOfBlock(ExitBlock); Value *Addr = cast<StoreInst>(Store)->getPointerOperand(); + Type *Ty = LiveInValue->getType(); IRBuilder<> Builder(InsertPos); if (AtomicCounterUpdatePromoted) // automic update currently can only be promoted across the current @@ -203,7 +205,7 @@ public: Builder.CreateAtomicRMW(AtomicRMWInst::Add, Addr, LiveInValue, AtomicOrdering::SequentiallyConsistent); else { - LoadInst *OldVal = Builder.CreateLoad(Addr, "pgocount.promoted"); + LoadInst *OldVal = Builder.CreateLoad(Ty, Addr, "pgocount.promoted"); auto *NewVal = Builder.CreateAdd(OldVal, LiveInValue); auto *NewStore = Builder.CreateStore(NewVal, Addr); @@ -232,9 +234,9 @@ class PGOCounterPromoter { public: PGOCounterPromoter( DenseMap<Loop *, SmallVector<LoadStorePair, 8>> &LoopToCands, - Loop &CurLoop, LoopInfo &LI) + Loop &CurLoop, LoopInfo &LI, BlockFrequencyInfo *BFI) : LoopToCandidates(LoopToCands), ExitBlocks(), InsertPts(), L(CurLoop), - LI(LI) { + LI(LI), BFI(BFI) { SmallVector<BasicBlock *, 8> LoopExitBlocks; SmallPtrSet<BasicBlock *, 8> BlockSet; @@ -263,6 +265,20 @@ public: SSAUpdater SSA(&NewPHIs); Value *InitVal = ConstantInt::get(Cand.first->getType(), 0); + // If BFI is set, we will use it to guide the promotions. + if (BFI) { + auto *BB = Cand.first->getParent(); + auto InstrCount = BFI->getBlockProfileCount(BB); + if (!InstrCount) + continue; + auto PreheaderCount = BFI->getBlockProfileCount(L.getLoopPreheader()); + // If the average loop trip count is not greater than 1.5, we skip + // promotion. + if (PreheaderCount && + (PreheaderCount.getValue() * 3) >= (InstrCount.getValue() * 2)) + continue; + } + PGOCounterPromoterHelper Promoter(Cand.first, Cand.second, SSA, InitVal, L.getLoopPreheader(), ExitBlocks, InsertPts, LoopToCandidates, LI); @@ -312,6 +328,11 @@ private: SmallVector<BasicBlock *, 8> ExitingBlocks; LP->getExitingBlocks(ExitingBlocks); + + // If BFI is set, we do more aggressive promotions based on BFI. + if (BFI) + return (unsigned)-1; + // Not considierered speculative. 
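// Worked example for the BFI gate added above (numbers illustrative): with a
// preheader profile count of 100, a counter-block count of 140 gives
//   100 * 3 = 300 >= 140 * 2 = 280  ->  average trip count 1.4 <= 1.5, skip;
// a counter-block count of 160 gives 300 < 320, so that counter is promoted
// out of the loop.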
if (ExitingBlocks.size() == 1) return MaxNumOfPromotionsPerLoop; @@ -343,6 +364,7 @@ private: SmallVector<Instruction *, 8> InsertPts; Loop &L; LoopInfo &LI; + BlockFrequencyInfo *BFI; }; } // end anonymous namespace @@ -365,8 +387,9 @@ INITIALIZE_PASS_END( "Frontend instrumentation-based coverage lowering.", false, false) ModulePass * -llvm::createInstrProfilingLegacyPass(const InstrProfOptions &Options) { - return new InstrProfilingLegacyPass(Options); +llvm::createInstrProfilingLegacyPass(const InstrProfOptions &Options, + bool IsCS) { + return new InstrProfilingLegacyPass(Options, IsCS); } static InstrProfIncrementInst *castToIncrementInst(Instruction *Instr) { @@ -415,6 +438,13 @@ void InstrProfiling::promoteCounterLoadStores(Function *F) { LoopInfo LI(DT); DenseMap<Loop *, SmallVector<LoadStorePair, 8>> LoopPromotionCandidates; + std::unique_ptr<BlockFrequencyInfo> BFI; + if (Options.UseBFIInPromotion) { + std::unique_ptr<BranchProbabilityInfo> BPI; + BPI.reset(new BranchProbabilityInfo(*F, LI, TLI)); + BFI.reset(new BlockFrequencyInfo(*F, *BPI, LI)); + } + for (const auto &LoadStore : PromotionCandidates) { auto *CounterLoad = LoadStore.first; auto *CounterStore = LoadStore.second; @@ -430,7 +460,7 @@ void InstrProfiling::promoteCounterLoadStores(Function *F) { // Do a post-order traversal of the loops so that counter updates can be // iteratively hoisted outside the loop nest. for (auto *Loop : llvm::reverse(Loops)) { - PGOCounterPromoter Promoter(LoopPromotionCandidates, *Loop, LI); + PGOCounterPromoter Promoter(LoopPromotionCandidates, *Loop, LI, BFI.get()); Promoter.run(&TotalCountersPromoted); } } @@ -509,13 +539,16 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { return true; } -static Constant *getOrInsertValueProfilingCall(Module &M, - const TargetLibraryInfo &TLI, - bool IsRange = false) { +static FunctionCallee +getOrInsertValueProfilingCall(Module &M, const TargetLibraryInfo &TLI, + bool IsRange = false) { LLVMContext &Ctx = M.getContext(); auto *ReturnTy = Type::getVoidTy(M.getContext()); - Constant *Res; + AttributeList AL; + if (auto AK = TLI.getExtAttrForI32Param(false)) + AL = AL.addParamAttribute(M.getContext(), 2, AK); + if (!IsRange) { Type *ParamTypes[] = { #define VALUE_PROF_FUNC_PARAM(ParamType, ParamName, ParamLLVMType) ParamLLVMType @@ -523,8 +556,8 @@ static Constant *getOrInsertValueProfilingCall(Module &M, }; auto *ValueProfilingCallTy = FunctionType::get(ReturnTy, makeArrayRef(ParamTypes), false); - Res = M.getOrInsertFunction(getInstrProfValueProfFuncName(), - ValueProfilingCallTy); + return M.getOrInsertFunction(getInstrProfValueProfFuncName(), + ValueProfilingCallTy, AL); } else { Type *RangeParamTypes[] = { #define VALUE_RANGE_PROF 1 @@ -534,15 +567,9 @@ static Constant *getOrInsertValueProfilingCall(Module &M, }; auto *ValueRangeProfilingCallTy = FunctionType::get(ReturnTy, makeArrayRef(RangeParamTypes), false); - Res = M.getOrInsertFunction(getInstrProfValueRangeProfFuncName(), - ValueRangeProfilingCallTy); + return M.getOrInsertFunction(getInstrProfValueRangeProfFuncName(), + ValueRangeProfilingCallTy, AL); } - - if (Function *FunRes = dyn_cast<Function>(Res)) { - if (auto AK = TLI.getExtAttrForI32Param(false)) - FunRes->addParamAttr(2, AK); - } - return Res; } void InstrProfiling::computeNumValueSiteCounts(InstrProfValueProfileInst *Ind) { @@ -601,13 +628,15 @@ void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) { IRBuilder<> Builder(Inc); uint64_t Index = Inc->getIndex()->getZExtValue(); - Value *Addr = 
Builder.CreateConstInBoundsGEP2_64(Counters, 0, Index); + Value *Addr = Builder.CreateConstInBoundsGEP2_64(Counters->getValueType(), + Counters, 0, Index); if (Options.Atomic || AtomicCounterUpdateAll) { Builder.CreateAtomicRMW(AtomicRMWInst::Add, Addr, Inc->getStep(), AtomicOrdering::Monotonic); } else { - Value *Load = Builder.CreateLoad(Addr, "pgocount"); + Value *IncStep = Inc->getStep(); + Value *Load = Builder.CreateLoad(IncStep->getType(), Addr, "pgocount"); auto *Count = Builder.CreateAdd(Load, Inc->getStep()); auto *Store = Builder.CreateStore(Count, Addr); if (isCounterPromotionEnabled()) @@ -678,32 +707,14 @@ static inline bool shouldRecordFunctionAddr(Function *F) { return F->hasAddressTaken() || F->hasLinkOnceLinkage(); } -static inline Comdat *getOrCreateProfileComdat(Module &M, Function &F, - InstrProfIncrementInst *Inc) { - if (!needsComdatForCounter(F, M)) - return nullptr; - - // COFF format requires a COMDAT section to have a key symbol with the same - // name. The linker targeting COFF also requires that the COMDAT - // a section is associated to must precede the associating section. For this - // reason, we must choose the counter var's name as the name of the comdat. - StringRef ComdatPrefix = (Triple(M.getTargetTriple()).isOSBinFormatCOFF() - ? getInstrProfCountersVarPrefix() - : getInstrProfComdatPrefix()); - return M.getOrInsertComdat(StringRef(getVarName(Inc, ComdatPrefix))); -} - -static bool needsRuntimeRegistrationOfSectionRange(const Module &M) { +static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) { // Don't do this for Darwin. compiler-rt uses linker magic. - if (Triple(M.getTargetTriple()).isOSDarwin()) + if (TT.isOSDarwin()) return false; - // Use linker script magic to get data/cnts/name start/end. - if (Triple(M.getTargetTriple()).isOSLinux() || - Triple(M.getTargetTriple()).isOSFreeBSD() || - Triple(M.getTargetTriple()).isOSNetBSD() || - Triple(M.getTargetTriple()).isOSFuchsia() || - Triple(M.getTargetTriple()).isPS4CPU()) + if (TT.isOSLinux() || TT.isOSFreeBSD() || TT.isOSNetBSD() || + TT.isOSSolaris() || TT.isOSFuchsia() || TT.isPS4CPU() || + TT.isOSWindows()) return false; return true; @@ -720,13 +731,37 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { PD = It->second; } - // Move the name variable to the right section. Place them in a COMDAT group - // if the associated function is a COMDAT. This will make sure that - // only one copy of counters of the COMDAT function will be emitted after - // linking. + // Match the linkage and visibility of the name global, except on COFF, where + // the linkage must be local and consequentially the visibility must be + // default. Function *Fn = Inc->getParent()->getParent(); - Comdat *ProfileVarsComdat = nullptr; - ProfileVarsComdat = getOrCreateProfileComdat(*M, *Fn, Inc); + GlobalValue::LinkageTypes Linkage = NamePtr->getLinkage(); + GlobalValue::VisibilityTypes Visibility = NamePtr->getVisibility(); + if (TT.isOSBinFormatCOFF()) { + Linkage = GlobalValue::InternalLinkage; + Visibility = GlobalValue::DefaultVisibility; + } + + // Move the name variable to the right section. Place them in a COMDAT group + // if the associated function is a COMDAT. This will make sure that only one + // copy of counters of the COMDAT function will be emitted after linking. Keep + // in mind that this pass may run before the inliner, so we need to create a + // new comdat group for the counters and profiling data. 
If we use the comdat + // of the parent function, that will result in relocations against discarded + // sections. + Comdat *Cmdt = nullptr; + GlobalValue::LinkageTypes CounterLinkage = Linkage; + if (needsComdatForCounter(*Fn, *M)) { + StringRef CmdtPrefix = getInstrProfComdatPrefix(); + if (TT.isOSBinFormatCOFF()) { + // For COFF, the comdat group name must be the name of a symbol in the + // group. Use the counter variable name, and upgrade its linkage to + // something externally visible, like linkonce_odr. + CmdtPrefix = getInstrProfCountersVarPrefix(); + CounterLinkage = GlobalValue::LinkOnceODRLinkage; + } + Cmdt = M->getOrInsertComdat(getVarName(Inc, CmdtPrefix)); + } uint64_t NumCounters = Inc->getNumCounters()->getZExtValue(); LLVMContext &Ctx = M->getContext(); @@ -734,20 +769,21 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { // Create the counters variable. auto *CounterPtr = - new GlobalVariable(*M, CounterTy, false, NamePtr->getLinkage(), + new GlobalVariable(*M, CounterTy, false, Linkage, Constant::getNullValue(CounterTy), getVarName(Inc, getInstrProfCountersVarPrefix())); - CounterPtr->setVisibility(NamePtr->getVisibility()); + CounterPtr->setVisibility(Visibility); CounterPtr->setSection( getInstrProfSectionName(IPSK_cnts, TT.getObjectFormat())); CounterPtr->setAlignment(8); - CounterPtr->setComdat(ProfileVarsComdat); + CounterPtr->setComdat(Cmdt); + CounterPtr->setLinkage(CounterLinkage); auto *Int8PtrTy = Type::getInt8PtrTy(Ctx); // Allocate statically the array of pointers to value profile nodes for // the current function. Constant *ValuesPtrExpr = ConstantPointerNull::get(Int8PtrTy); - if (ValueProfileStaticAlloc && !needsRuntimeRegistrationOfSectionRange(*M)) { + if (ValueProfileStaticAlloc && !needsRuntimeRegistrationOfSectionRange(TT)) { uint64_t NS = 0; for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) NS += PD.NumValueSites[Kind]; @@ -755,14 +791,14 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { ArrayType *ValuesTy = ArrayType::get(Type::getInt64Ty(Ctx), NS); auto *ValuesVar = - new GlobalVariable(*M, ValuesTy, false, NamePtr->getLinkage(), + new GlobalVariable(*M, ValuesTy, false, Linkage, Constant::getNullValue(ValuesTy), getVarName(Inc, getInstrProfValuesVarPrefix())); - ValuesVar->setVisibility(NamePtr->getVisibility()); + ValuesVar->setVisibility(Visibility); ValuesVar->setSection( getInstrProfSectionName(IPSK_vals, TT.getObjectFormat())); ValuesVar->setAlignment(8); - ValuesVar->setComdat(ProfileVarsComdat); + ValuesVar->setComdat(Cmdt); ValuesPtrExpr = ConstantExpr::getBitCast(ValuesVar, Type::getInt8PtrTy(Ctx)); } @@ -789,13 +825,13 @@ InstrProfiling::getOrCreateRegionCounters(InstrProfIncrementInst *Inc) { #define INSTR_PROF_DATA(Type, LLVMType, Name, Init) Init, #include "llvm/ProfileData/InstrProfData.inc" }; - auto *Data = new GlobalVariable(*M, DataTy, false, NamePtr->getLinkage(), + auto *Data = new GlobalVariable(*M, DataTy, false, Linkage, ConstantStruct::get(DataTy, DataVals), getVarName(Inc, getInstrProfDataVarPrefix())); - Data->setVisibility(NamePtr->getVisibility()); + Data->setVisibility(Visibility); Data->setSection(getInstrProfSectionName(IPSK_data, TT.getObjectFormat())); Data->setAlignment(INSTR_PROF_DATA_ALIGNMENT); - Data->setComdat(ProfileVarsComdat); + Data->setComdat(Cmdt); PD.RegionCounters = CounterPtr; PD.DataVar = Data; @@ -820,7 +856,7 @@ void InstrProfiling::emitVNodes() { // For now only support this on platforms that do // not require runtime registration to 
discover // named section start/end. - if (needsRuntimeRegistrationOfSectionRange(*M)) + if (needsRuntimeRegistrationOfSectionRange(TT)) return; size_t TotalNS = 0; @@ -881,6 +917,10 @@ void InstrProfiling::emitNameData() { NamesSize = CompressedNameStr.size(); NamesVar->setSection( getInstrProfSectionName(IPSK_name, TT.getObjectFormat())); + // On COFF, it's important to reduce the alignment down to 1 to prevent the + // linker from inserting padding before the start of the names section or + // between names entries. + NamesVar->setAlignment(1); UsedVars.push_back(NamesVar); for (auto *NamePtr : ReferencedNames) @@ -888,7 +928,7 @@ void InstrProfiling::emitNameData() { } void InstrProfiling::emitRegistration() { - if (!needsRuntimeRegistrationOfSectionRange(*M)) + if (!needsRuntimeRegistrationOfSectionRange(TT)) return; // Construct the function. @@ -929,7 +969,7 @@ void InstrProfiling::emitRegistration() { bool InstrProfiling::emitRuntimeHook() { // We expect the linker to be invoked with -u<hook_var> flag for linux, // for which case there is no need to emit the user function. - if (Triple(M->getTargetTriple()).isOSLinux()) + if (TT.isOSLinux()) return false; // If the module's provided its own runtime, we don't need to do anything. @@ -950,11 +990,11 @@ bool InstrProfiling::emitRuntimeHook() { if (Options.NoRedZone) User->addFnAttr(Attribute::NoRedZone); User->setVisibility(GlobalValue::HiddenVisibility); - if (Triple(M->getTargetTriple()).supportsCOMDAT()) + if (TT.supportsCOMDAT()) User->setComdat(M->getOrInsertComdat(User->getName())); IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", User)); - auto *Load = IRB.CreateLoad(Var); + auto *Load = IRB.CreateLoad(Int32Ty, Var); IRB.CreateRet(Load); // Mark the user variable as used so that it isn't stripped out. @@ -968,23 +1008,13 @@ void InstrProfiling::emitUses() { } void InstrProfiling::emitInitialization() { - StringRef InstrProfileOutput = Options.InstrProfileOutput; - - if (!InstrProfileOutput.empty()) { - // Create variable for profile name. - Constant *ProfileNameConst = - ConstantDataArray::getString(M->getContext(), InstrProfileOutput, true); - GlobalVariable *ProfileNameVar = new GlobalVariable( - *M, ProfileNameConst->getType(), true, GlobalValue::WeakAnyLinkage, - ProfileNameConst, INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_NAME_VAR)); - if (TT.supportsCOMDAT()) { - ProfileNameVar->setLinkage(GlobalValue::ExternalLinkage); - ProfileNameVar->setComdat(M->getOrInsertComdat( - StringRef(INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_NAME_VAR)))); - } - } - - Constant *RegisterF = M->getFunction(getInstrProfRegFuncsName()); + // Create ProfileFileName variable. Don't don't this for the + // context-sensitive instrumentation lowering: This lowering is after + // LTO/ThinLTO linking. Pass PGOInstrumentationGenCreateVar should + // have already create the variable before LTO/ThinLTO linking. + if (!IsCS) + createProfileFileNameVar(*M, Options.InstrProfileOutput); + Function *RegisterF = M->getFunction(getInstrProfRegFuncsName()); if (!RegisterF) return; @@ -1000,8 +1030,7 @@ void InstrProfiling::emitInitialization() { // Add the basic block and the necessary calls. 
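[Editor's note] Several hunks in this file (lowerIncrement, emitRuntimeHook, the GEP over the counter array) move to the IRBuilder overloads that take the loaded or indexed type explicitly instead of deriving it from the pointer operand. A minimal sketch of that convention, assuming an [N x i64] counter global like the one this pass creates; the helper name is illustrative, not from the patch:

#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

static void emitCounterIncrement(IRBuilder<> &B, GlobalVariable *Counters,
                                 uint64_t Index) {
  // The GEP names the global's value type, not its pointer type.
  Value *Addr = B.CreateConstInBoundsGEP2_64(Counters->getValueType(),
                                             Counters, 0, Index);
  // The load names the loaded type explicitly.
  Type *Int64Ty = B.getInt64Ty();
  Value *Old = B.CreateLoad(Int64Ty, Addr, "pgocount");
  Value *New = B.CreateAdd(Old, ConstantInt::get(Int64Ty, 1));
  B.CreateStore(New, Addr);
}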
IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", F)); - if (RegisterF) - IRB.CreateCall(RegisterF, {}); + IRB.CreateCall(RegisterF, {}); IRB.CreateRetVoid(); appendToGlobalCtors(*M, F, 0); diff --git a/lib/Transforms/Instrumentation/Instrumentation.cpp b/lib/Transforms/Instrumentation/Instrumentation.cpp index c3e323613c70..f56a1bd91b89 100644 --- a/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -1,9 +1,8 @@ //===-- Instrumentation.cpp - TransformUtils Infrastructure ---------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -25,10 +24,12 @@ using namespace llvm; /// Moves I before IP. Returns new insert point. static BasicBlock::iterator moveBeforeInsertPoint(BasicBlock::iterator I, BasicBlock::iterator IP) { // If I is IP, move the insert point down. - if (I == IP) - return ++IP; - // Otherwise, move I before IP and return IP. - I->moveBefore(&*IP); + if (I == IP) { + ++IP; + } else { + // Otherwise, move I before IP and return IP. + I->moveBefore(&*IP); + } return IP; } @@ -101,8 +102,8 @@ Comdat *llvm::GetOrCreateFunctionComdat(Function &F, Triple &T, /// initializeInstrumentation - Initialize all passes in the TransformUtils /// library. void llvm::initializeInstrumentation(PassRegistry &Registry) { - initializeAddressSanitizerPass(Registry); - initializeAddressSanitizerModulePass(Registry); + initializeAddressSanitizerLegacyPassPass(Registry); + initializeModuleAddressSanitizerLegacyPassPass(Registry); initializeBoundsCheckingLegacyPassPass(Registry); initializeControlHeightReductionLegacyPassPass(Registry); initializeGCOVProfilerLegacyPassPass(Registry); @@ -110,13 +111,13 @@ void llvm::initializeInstrumentation(PassRegistry &Registry) { initializePGOInstrumentationUseLegacyPassPass(Registry); initializePGOIndirectCallPromotionLegacyPassPass(Registry); initializePGOMemOPSizeOptLegacyPassPass(Registry); + initializeInstrOrderFileLegacyPassPass(Registry); initializeInstrProfilingLegacyPassPass(Registry); initializeMemorySanitizerLegacyPassPass(Registry); - initializeHWAddressSanitizerPass(Registry); + initializeHWAddressSanitizerLegacyPassPass(Registry); initializeThreadSanitizerLegacyPassPass(Registry); initializeSanitizerCoverageModulePass(Registry); initializeDataFlowSanitizerPass(Registry); - initializeEfficiencySanitizerPass(Registry); } /// LLVMInitializeInstrumentation - C binding for diff --git a/lib/Transforms/Instrumentation/MaximumSpanningTree.h b/lib/Transforms/Instrumentation/MaximumSpanningTree.h index 4eb758c69c58..892a6a26da91 100644 --- a/lib/Transforms/Instrumentation/MaximumSpanningTree.h +++ b/lib/Transforms/Instrumentation/MaximumSpanningTree.h @@ -1,9 +1,8 @@ //===- llvm/Analysis/MaximumSpanningTree.h - Interface ----------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -68,8 +67,7 @@ namespace llvm { /// MaximumSpanningTree() - Takes a vector of weighted edges and returns a /// spanning tree. MaximumSpanningTree(EdgeWeights &EdgeVector) { - - std::stable_sort(EdgeVector.begin(), EdgeVector.end(), EdgeWeightCompare()); + llvm::stable_sort(EdgeVector, EdgeWeightCompare()); // Create spanning tree, Forest contains a special data structure // that makes checking if two nodes are already in a common (sub-)tree diff --git a/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/lib/Transforms/Instrumentation/MemorySanitizer.cpp index e6573af2077d..b25cbed1bb02 100644 --- a/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -1,9 +1,8 @@ //===- MemorySanitizer.cpp - detector of uninitialized reads --------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -144,6 +143,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" @@ -248,6 +248,13 @@ static cl::opt<bool> ClHandleICmpExact("msan-handle-icmp-exact", cl::desc("exact handling of relational integer ICmp"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClHandleLifetimeIntrinsics( + "msan-handle-lifetime-intrinsics", + cl::desc( + "when possible, poison scoped variables at the beginning of the scope " + "(slower, but more precise)"), + cl::Hidden, cl::init(true)); + // When compiling the Linux kernel, we sometimes see false positives related to // MSan being unable to understand that inline assembly calls may initialize // local variables. @@ -305,22 +312,23 @@ static cl::opt<bool> ClWithComdat("msan-with-comdat", // These options allow to specify custom memory map parameters // See MemoryMapParams for details. 
-static cl::opt<unsigned long long> ClAndMask("msan-and-mask", - cl::desc("Define custom MSan AndMask"), - cl::Hidden, cl::init(0)); +static cl::opt<uint64_t> ClAndMask("msan-and-mask", + cl::desc("Define custom MSan AndMask"), + cl::Hidden, cl::init(0)); -static cl::opt<unsigned long long> ClXorMask("msan-xor-mask", - cl::desc("Define custom MSan XorMask"), - cl::Hidden, cl::init(0)); +static cl::opt<uint64_t> ClXorMask("msan-xor-mask", + cl::desc("Define custom MSan XorMask"), + cl::Hidden, cl::init(0)); -static cl::opt<unsigned long long> ClShadowBase("msan-shadow-base", - cl::desc("Define custom MSan ShadowBase"), - cl::Hidden, cl::init(0)); +static cl::opt<uint64_t> ClShadowBase("msan-shadow-base", + cl::desc("Define custom MSan ShadowBase"), + cl::Hidden, cl::init(0)); -static cl::opt<unsigned long long> ClOriginBase("msan-origin-base", - cl::desc("Define custom MSan OriginBase"), - cl::Hidden, cl::init(0)); +static cl::opt<uint64_t> ClOriginBase("msan-origin-base", + cl::desc("Define custom MSan OriginBase"), + cl::Hidden, cl::init(0)); +static const char *const kMsanModuleCtorName = "msan.module_ctor"; static const char *const kMsanInitName = "__msan_init"; namespace { @@ -454,17 +462,16 @@ namespace { /// the module. class MemorySanitizer { public: - MemorySanitizer(Module &M, int TrackOrigins = 0, bool Recover = false, - bool EnableKmsan = false) { + MemorySanitizer(Module &M, MemorySanitizerOptions Options) { this->CompileKernel = - ClEnableKmsan.getNumOccurrences() > 0 ? ClEnableKmsan : EnableKmsan; + ClEnableKmsan.getNumOccurrences() > 0 ? ClEnableKmsan : Options.Kernel; if (ClTrackOrigins.getNumOccurrences() > 0) this->TrackOrigins = ClTrackOrigins; else - this->TrackOrigins = this->CompileKernel ? 2 : TrackOrigins; + this->TrackOrigins = this->CompileKernel ? 2 : Options.TrackOrigins; this->Recover = ClKeepGoing.getNumOccurrences() > 0 ? ClKeepGoing - : (this->CompileKernel | Recover); + : (this->CompileKernel | Options.Recover); initializeModule(M); } @@ -536,41 +543,42 @@ private: bool CallbacksInitialized = false; /// The run-time callback to print a warning. - Value *WarningFn; + FunctionCallee WarningFn; // These arrays are indexed by log2(AccessSize). - Value *MaybeWarningFn[kNumberOfAccessSizes]; - Value *MaybeStoreOriginFn[kNumberOfAccessSizes]; + FunctionCallee MaybeWarningFn[kNumberOfAccessSizes]; + FunctionCallee MaybeStoreOriginFn[kNumberOfAccessSizes]; /// Run-time helper that generates a new origin value for a stack /// allocation. - Value *MsanSetAllocaOrigin4Fn; + FunctionCallee MsanSetAllocaOrigin4Fn; /// Run-time helper that poisons stack on function entry. - Value *MsanPoisonStackFn; + FunctionCallee MsanPoisonStackFn; /// Run-time helper that records a store (or any event) of an /// uninitialized value and returns an updated origin id encoding this info. - Value *MsanChainOriginFn; + FunctionCallee MsanChainOriginFn; /// MSan runtime replacements for memmove, memcpy and memset. - Value *MemmoveFn, *MemcpyFn, *MemsetFn; + FunctionCallee MemmoveFn, MemcpyFn, MemsetFn; /// KMSAN callback for task-local function argument shadow. - Value *MsanGetContextStateFn; + StructType *MsanContextStateTy; + FunctionCallee MsanGetContextStateFn; /// Functions for poisoning/unpoisoning local variables - Value *MsanPoisonAllocaFn, *MsanUnpoisonAllocaFn; + FunctionCallee MsanPoisonAllocaFn, MsanUnpoisonAllocaFn; /// Each of the MsanMetadataPtrXxx functions returns a pair of shadow/origin /// pointers. 
- Value *MsanMetadataPtrForLoadN, *MsanMetadataPtrForStoreN; - Value *MsanMetadataPtrForLoad_1_8[4]; - Value *MsanMetadataPtrForStore_1_8[4]; - Value *MsanInstrumentAsmStoreFn; + FunctionCallee MsanMetadataPtrForLoadN, MsanMetadataPtrForStoreN; + FunctionCallee MsanMetadataPtrForLoad_1_8[4]; + FunctionCallee MsanMetadataPtrForStore_1_8[4]; + FunctionCallee MsanInstrumentAsmStoreFn; /// Helper to choose between different MsanMetadataPtrXxx(). - Value *getKmsanShadowOriginAccessFn(bool isStore, int size); + FunctionCallee getKmsanShadowOriginAccessFn(bool isStore, int size); /// Memory map parameters used in application-to-shadow calculation. const MemoryMapParams *MapParams; @@ -586,6 +594,8 @@ private: /// An empty volatile inline asm that prevents callback merge. InlineAsm *EmptyAsm; + + Function *MsanCtorFunction; }; /// A legacy function pass for msan instrumentation. @@ -595,10 +605,8 @@ struct MemorySanitizerLegacyPass : public FunctionPass { // Pass identification, replacement for typeid. static char ID; - MemorySanitizerLegacyPass(int TrackOrigins = 0, bool Recover = false, - bool EnableKmsan = false) - : FunctionPass(ID), TrackOrigins(TrackOrigins), Recover(Recover), - EnableKmsan(EnableKmsan) {} + MemorySanitizerLegacyPass(MemorySanitizerOptions Options = {}) + : FunctionPass(ID), Options(Options) {} StringRef getPassName() const override { return "MemorySanitizerLegacyPass"; } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -612,16 +620,14 @@ struct MemorySanitizerLegacyPass : public FunctionPass { bool doInitialization(Module &M) override; Optional<MemorySanitizer> MSan; - int TrackOrigins; - bool Recover; - bool EnableKmsan; + MemorySanitizerOptions Options; }; } // end anonymous namespace PreservedAnalyses MemorySanitizerPass::run(Function &F, FunctionAnalysisManager &FAM) { - MemorySanitizer Msan(*F.getParent(), TrackOrigins, Recover, EnableKmsan); + MemorySanitizer Msan(*F.getParent(), Options); if (Msan.sanitizeFunction(F, FAM.getResult<TargetLibraryAnalysis>(F))) return PreservedAnalyses::none(); return PreservedAnalyses::all(); @@ -637,10 +643,9 @@ INITIALIZE_PASS_END(MemorySanitizerLegacyPass, "msan", "MemorySanitizer: detects uninitialized reads.", false, false) -FunctionPass *llvm::createMemorySanitizerLegacyPassPass(int TrackOrigins, - bool Recover, - bool CompileKernel) { - return new MemorySanitizerLegacyPass(TrackOrigins, Recover, CompileKernel); +FunctionPass * +llvm::createMemorySanitizerLegacyPassPass(MemorySanitizerOptions Options) { + return new MemorySanitizerLegacyPass(Options); } /// Create a non-const global initialized with the given string. @@ -675,18 +680,15 @@ void MemorySanitizer::createKernelApi(Module &M) { IRB.getInt32Ty()); // Requests the per-task context state (kmsan_context_state*) from the // runtime library. 
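[Editor's note] The MemorySanitizer hunks above replace raw Value*/Constant* callback handles with FunctionCallee, which Module::getOrInsertFunction now returns. A minimal sketch of that pattern with a hypothetical runtime hook (the name __example_warn is illustrative only, not a real MSan entry point):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"

using namespace llvm;

static FunctionCallee getWarningFn(Module &M) {
  LLVMContext &C = M.getContext();
  // Hypothetical hook: void __example_warn(i32).
  return M.getOrInsertFunction("__example_warn", Type::getVoidTy(C),
                               Type::getInt32Ty(C));
}

static void emitWarning(IRBuilder<> &B, Module &M, Value *Origin) {
  FunctionCallee Fn = getWarningFn(M);
  // CreateCall accepts the FunctionCallee directly; the function type no
  // longer has to be recovered from the callee pointer's type.
  B.CreateCall(Fn, {Origin});
}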
+ MsanContextStateTy = StructType::get( + ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), + ArrayType::get(IRB.getInt64Ty(), kRetvalTLSSize / 8), + ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), + ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), /* va_arg_origin */ + IRB.getInt64Ty(), ArrayType::get(OriginTy, kParamTLSSize / 4), OriginTy, + OriginTy); MsanGetContextStateFn = M.getOrInsertFunction( - "__msan_get_context_state", - PointerType::get( - StructType::get(ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), - ArrayType::get(IRB.getInt64Ty(), kRetvalTLSSize / 8), - ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), - ArrayType::get(IRB.getInt64Ty(), - kParamTLSSize / 8), /* va_arg_origin */ - IRB.getInt64Ty(), - ArrayType::get(OriginTy, kParamTLSSize / 4), OriginTy, - OriginTy), - 0)); + "__msan_get_context_state", PointerType::get(MsanContextStateTy, 0)); Type *RetTy = StructType::get(PointerType::get(IRB.getInt8Ty(), 0), PointerType::get(IRB.getInt32Ty(), 0)); @@ -821,8 +823,9 @@ void MemorySanitizer::initializeCallbacks(Module &M) { CallbacksInitialized = true; } -Value *MemorySanitizer::getKmsanShadowOriginAccessFn(bool isStore, int size) { - Value **Fns = +FunctionCallee MemorySanitizer::getKmsanShadowOriginAccessFn(bool isStore, + int size) { + FunctionCallee *Fns = isStore ? MsanMetadataPtrForStore_1_8 : MsanMetadataPtrForLoad_1_8; switch (size) { case 1: @@ -839,6 +842,8 @@ Value *MemorySanitizer::getKmsanShadowOriginAccessFn(bool isStore, int size) { } /// Module-level initialization. +/// +/// inserts a call to __msan_init to the module's constructor list. void MemorySanitizer::initializeModule(Module &M) { auto &DL = M.getDataLayout(); @@ -913,7 +918,22 @@ void MemorySanitizer::initializeModule(Module &M) { OriginStoreWeights = MDBuilder(*C).createBranchWeights(1, 1000); if (!CompileKernel) { - getOrCreateInitFunction(M, kMsanInitName); + std::tie(MsanCtorFunction, std::ignore) = + getOrCreateSanitizerCtorAndInitFunctions( + M, kMsanModuleCtorName, kMsanInitName, + /*InitArgTypes=*/{}, + /*InitArgs=*/{}, + // This callback is invoked when the functions are created the first + // time. Hook them into the global ctors list in that case: + [&](Function *Ctor, FunctionCallee) { + if (!ClWithComdat) { + appendToGlobalCtors(M, Ctor, 0); + return; + } + Comdat *MsanCtorComdat = M.getOrInsertComdat(kMsanModuleCtorName); + Ctor->setComdat(MsanCtorComdat); + appendToGlobalCtors(M, Ctor, 0, Ctor); + }); if (TrackOrigins) M.getOrInsertGlobal("__msan_track_origins", IRB.getInt32Ty(), [&] { @@ -932,7 +952,7 @@ void MemorySanitizer::initializeModule(Module &M) { } bool MemorySanitizerLegacyPass::doInitialization(Module &M) { - MSan.emplace(M, TrackOrigins, Recover, EnableKmsan); + MSan.emplace(M, Options); return true; } @@ -1011,6 +1031,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { : Shadow(S), Origin(O), OrigIns(I) {} }; SmallVector<ShadowOriginAndInsertPoint, 16> InstrumentationList; + bool InstrumentLifetimeStart = ClHandleLifetimeIntrinsics; + SmallSet<AllocaInst *, 16> AllocaSet; + SmallVector<std::pair<IntrinsicInst *, AllocaInst *>, 16> LifetimeStartList; SmallVector<StoreInst *, 16> StoreList; MemorySanitizerVisitor(Function &F, MemorySanitizer &MS, @@ -1076,7 +1099,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { for (unsigned i = Ofs; i < (Size + kOriginSize - 1) / kOriginSize; ++i) { Value *GEP = - i ? IRB.CreateConstGEP1_32(nullptr, OriginPtr, i) : OriginPtr; + i ? 
IRB.CreateConstGEP1_32(MS.OriginTy, OriginPtr, i) : OriginPtr; IRB.CreateAlignedStore(Origin, GEP, CurrentAlignment); CurrentAlignment = kMinOriginAlignment; } @@ -1104,7 +1127,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { DL.getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); if (AsCall && SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { - Value *Fn = MS.MaybeStoreOriginFn[SizeIndex]; + FunctionCallee Fn = MS.MaybeStoreOriginFn[SizeIndex]; Value *ConvertedShadow2 = IRB.CreateZExt( ConvertedShadow, IRB.getIntNTy(8 * (1 << SizeIndex))); IRB.CreateCall(Fn, {ConvertedShadow2, @@ -1186,7 +1209,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { unsigned TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); if (AsCall && SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { - Value *Fn = MS.MaybeWarningFn[SizeIndex]; + FunctionCallee Fn = MS.MaybeWarningFn[SizeIndex]; Value *ConvertedShadow2 = IRB.CreateZExt(ConvertedShadow, IRB.getIntNTy(8 * (1 << SizeIndex))); IRB.CreateCall(Fn, {ConvertedShadow2, MS.TrackOrigins && Origin @@ -1221,20 +1244,22 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); Value *ContextState = IRB.CreateCall(MS.MsanGetContextStateFn, {}); Constant *Zero = IRB.getInt32(0); - MS.ParamTLS = - IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(0)}, "param_shadow"); - MS.RetvalTLS = - IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(1)}, "retval_shadow"); - MS.VAArgTLS = - IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(2)}, "va_arg_shadow"); - MS.VAArgOriginTLS = - IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(3)}, "va_arg_origin"); - MS.VAArgOverflowSizeTLS = IRB.CreateGEP( - ContextState, {Zero, IRB.getInt32(4)}, "va_arg_overflow_size"); - MS.ParamOriginTLS = - IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(5)}, "param_origin"); + MS.ParamTLS = IRB.CreateGEP(MS.MsanContextStateTy, ContextState, + {Zero, IRB.getInt32(0)}, "param_shadow"); + MS.RetvalTLS = IRB.CreateGEP(MS.MsanContextStateTy, ContextState, + {Zero, IRB.getInt32(1)}, "retval_shadow"); + MS.VAArgTLS = IRB.CreateGEP(MS.MsanContextStateTy, ContextState, + {Zero, IRB.getInt32(2)}, "va_arg_shadow"); + MS.VAArgOriginTLS = IRB.CreateGEP(MS.MsanContextStateTy, ContextState, + {Zero, IRB.getInt32(3)}, "va_arg_origin"); + MS.VAArgOverflowSizeTLS = + IRB.CreateGEP(MS.MsanContextStateTy, ContextState, + {Zero, IRB.getInt32(4)}, "va_arg_overflow_size"); + MS.ParamOriginTLS = IRB.CreateGEP(MS.MsanContextStateTy, ContextState, + {Zero, IRB.getInt32(5)}, "param_origin"); MS.RetvalOriginTLS = - IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(6)}, "retval_origin"); + IRB.CreateGEP(MS.MsanContextStateTy, ContextState, + {Zero, IRB.getInt32(6)}, "retval_origin"); return ret; } @@ -1265,6 +1290,19 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { VAHelper->finalizeInstrumentation(); + // Poison llvm.lifetime.start intrinsics, if we haven't fallen back to + // instrumenting only allocas. + if (InstrumentLifetimeStart) { + for (auto Item : LifetimeStartList) { + instrumentAlloca(*Item.second, Item.first); + AllocaSet.erase(Item.second); + } + } + // Poison the allocas for which we didn't instrument the corresponding + // lifetime intrinsics. 
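[Editor's note] A condensed, standalone model of the llvm.lifetime.start bookkeeping this patch adds to MSan (the fields and finalize loop above, plus handleLifetimeStart further down). Types are simplified and the poisoning calls are elided; this is not the pass code: markers are mapped back to their allocas when possible, any unresolvable marker disables marker-based poisoning for the whole function, and allocas poisoned via a marker are removed from the fallback set.

#include <set>
#include <vector>

struct Alloca {};
struct Lifetime { Alloca *AI; };  // AI == nullptr if the marker is untraceable

static void poisonLocals(const std::vector<Lifetime> &Markers,
                         std::set<Alloca *> &AllocaSet) {
  bool UseMarkers = true;               // mirrors InstrumentLifetimeStart
  for (const Lifetime &L : Markers)
    if (!L.AI)
      UseMarkers = false;               // can't trust markers; use fallback only
  if (UseMarkers)
    for (const Lifetime &L : Markers) {
      /* poison L.AI at the marker ... */
      AllocaSet.erase(L.AI);            // don't poison it a second time below
    }
  for (Alloca *A : AllocaSet) {
    /* poison A right after its definition ... */
    (void)A;
  }
}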
+ for (AllocaInst *AI : AllocaSet) + instrumentAlloca(*AI); + bool InstrumentWithCalls = ClInstrumentationWithCallThreshold >= 0 && InstrumentationList.size() + StoreList.size() > (unsigned)ClInstrumentationWithCallThreshold; @@ -1381,7 +1419,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRB.CreateAnd(OriginLong, ConstantInt::get(MS.IntptrTy, ~Mask)); } OriginPtr = - IRB.CreateIntToPtr(OriginLong, PointerType::get(IRB.getInt32Ty(), 0)); + IRB.CreateIntToPtr(OriginLong, PointerType::get(MS.OriginTy, 0)); } return std::make_pair(ShadowPtr, OriginPtr); } @@ -1393,7 +1431,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { const DataLayout &DL = F.getParent()->getDataLayout(); int Size = DL.getTypeStoreSize(ShadowTy); - Value *Getter = MS.getKmsanShadowOriginAccessFn(isStore, Size); + FunctionCallee Getter = MS.getKmsanShadowOriginAccessFn(isStore, Size); Value *AddrCast = IRB.CreatePointerCast(Addr, PointerType::get(IRB.getInt8Ty(), 0)); if (Getter) { @@ -1598,8 +1636,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // ParamTLS overflow. *ShadowPtr = getCleanShadow(V); } else { - *ShadowPtr = - EntryIRB.CreateAlignedLoad(Base, kShadowTLSAlignment); + *ShadowPtr = EntryIRB.CreateAlignedLoad(getShadowTy(&FArg), Base, + kShadowTLSAlignment); } } LLVM_DEBUG(dbgs() @@ -1607,7 +1645,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (MS.TrackOrigins && !Overflow) { Value *OriginPtr = getOriginPtrForArgument(&FArg, EntryIRB, ArgOffset); - setOrigin(A, EntryIRB.CreateLoad(OriginPtr)); + setOrigin(A, EntryIRB.CreateLoad(MS.OriginTy, OriginPtr)); } else { setOrigin(A, getCleanOrigin()); } @@ -1738,7 +1776,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (PropagateShadow) { std::tie(ShadowPtr, OriginPtr) = getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment, /*isStore*/ false); - setShadow(&I, IRB.CreateAlignedLoad(ShadowPtr, Alignment, "_msld")); + setShadow(&I, + IRB.CreateAlignedLoad(ShadowTy, ShadowPtr, Alignment, "_msld")); } else { setShadow(&I, getCleanShadow(&I)); } @@ -1752,7 +1791,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (MS.TrackOrigins) { if (PropagateShadow) { unsigned OriginAlignment = std::max(kMinOriginAlignment, Alignment); - setOrigin(&I, IRB.CreateAlignedLoad(OriginPtr, OriginAlignment)); + setOrigin( + &I, IRB.CreateAlignedLoad(MS.OriginTy, OriginPtr, OriginAlignment)); } else { setOrigin(&I, getCleanOrigin()); } @@ -1903,7 +1943,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *S1S2 = IRB.CreateAnd(S1, S2); Value *V1S2 = IRB.CreateAnd(V1, S2); Value *S1V2 = IRB.CreateAnd(S1, V2); - setShadow(&I, IRB.CreateOr(S1S2, IRB.CreateOr(V1S2, S1V2))); + setShadow(&I, IRB.CreateOr({S1S2, V1S2, S1V2})); setOriginForNaryOp(I); } @@ -1925,7 +1965,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *S1S2 = IRB.CreateAnd(S1, S2); Value *V1S2 = IRB.CreateAnd(V1, S2); Value *S1V2 = IRB.CreateAnd(S1, V2); - setShadow(&I, IRB.CreateOr(S1S2, IRB.CreateOr(V1S2, S1V2))); + setShadow(&I, IRB.CreateOr({S1S2, V1S2, S1V2})); setOriginForNaryOp(I); } @@ -2070,6 +2110,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { SC.Done(&I); } + void visitFNeg(UnaryOperator &I) { handleShadowOr(I); } + // Handle multiplication by constant. 
// // Handle a special case of multiplication by constant that may have one or @@ -2432,7 +2474,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { unsigned Alignment = 1; std::tie(ShadowPtr, OriginPtr) = getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment, /*isStore*/ false); - setShadow(&I, IRB.CreateAlignedLoad(ShadowPtr, Alignment, "_msld")); + setShadow(&I, + IRB.CreateAlignedLoad(ShadowTy, ShadowPtr, Alignment, "_msld")); } else { setShadow(&I, getCleanShadow(&I)); } @@ -2442,7 +2485,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (MS.TrackOrigins) { if (PropagateShadow) - setOrigin(&I, IRB.CreateLoad(OriginPtr)); + setOrigin(&I, IRB.CreateLoad(MS.OriginTy, OriginPtr)); else setOrigin(&I, getCleanOrigin()); } @@ -2519,6 +2562,17 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return false; } + void handleLifetimeStart(IntrinsicInst &I) { + if (!PoisonStack) + return; + DenseMap<Value *, AllocaInst *> AllocaForValue; + AllocaInst *AI = + llvm::findAllocaForValue(I.getArgOperand(1), AllocaForValue); + if (!AI) + InstrumentLifetimeStart = false; + LifetimeStartList.push_back(std::make_pair(&I, AI)); + } + void handleBswap(IntrinsicInst &I) { IRBuilder<> IRB(&I); Value *Op = I.getArgOperand(0); @@ -2650,7 +2704,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { : Lower64ShadowExtend(IRB, S2, getShadowTy(&I)); Value *V1 = I.getOperand(0); Value *V2 = I.getOperand(1); - Value *Shift = IRB.CreateCall(I.getCalledValue(), + Value *Shift = IRB.CreateCall(I.getFunctionType(), I.getCalledValue(), {IRB.CreateBitCast(S1, V1->getType()), V2}); Shift = IRB.CreateBitCast(Shift, getShadowTy(&I)); setShadow(&I, IRB.CreateOr(Shift, S2Conv)); @@ -2660,6 +2714,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Get an X86_MMX-sized vector type. Type *getMMXVectorTy(unsigned EltSizeInBits) { const unsigned X86_MMXSizeInBits = 64; + assert(EltSizeInBits != 0 && (X86_MMXSizeInBits % EltSizeInBits) == 0 && + "Illegal MMX vector element size"); return VectorType::get(IntegerType::get(*MS.C, EltSizeInBits), X86_MMXSizeInBits / EltSizeInBits); } @@ -2825,9 +2881,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (ClCheckAccessAddress) insertShadowCheck(Addr, &I); - Value *Shadow = IRB.CreateAlignedLoad(ShadowPtr, Alignment, "_ldmxcsr"); - Value *Origin = - MS.TrackOrigins ? IRB.CreateLoad(OriginPtr) : getCleanOrigin(); + Value *Shadow = IRB.CreateAlignedLoad(Ty, ShadowPtr, Alignment, "_ldmxcsr"); + Value *Origin = MS.TrackOrigins ? IRB.CreateLoad(MS.OriginTy, OriginPtr) + : getCleanOrigin(); insertShadowCheck(Shadow, Origin, &I); } @@ -2901,7 +2957,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Origin = IRB.CreateSelect( IRB.CreateICmpNE(Acc, Constant::getNullValue(Acc->getType())), - getOrigin(PassThru), IRB.CreateLoad(OriginPtr)); + getOrigin(PassThru), IRB.CreateLoad(MS.OriginTy, OriginPtr)); setOrigin(&I, Origin); } else { @@ -2911,9 +2967,32 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return true; } + // Instrument BMI / BMI2 intrinsics. + // All of these intrinsics are Z = I(X, Y) + // where the types of all operands and the result match, and are either i32 or i64. 
+ // The following instrumentation happens to work for all of them: + // Sz = I(Sx, Y) | (sext (Sy != 0)) + void handleBmiIntrinsic(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Type *ShadowTy = getShadowTy(&I); + + // If any bit of the mask operand is poisoned, then the whole thing is. + Value *SMask = getShadow(&I, 1); + SMask = IRB.CreateSExt(IRB.CreateICmpNE(SMask, getCleanShadow(ShadowTy)), + ShadowTy); + // Apply the same intrinsic to the shadow of the first operand. + Value *S = IRB.CreateCall(I.getCalledFunction(), + {getShadow(&I, 0), I.getOperand(1)}); + S = IRB.CreateOr(SMask, S); + setShadow(&I, S); + setOriginForNaryOp(I); + } void visitIntrinsicInst(IntrinsicInst &I) { switch (I.getIntrinsicID()) { + case Intrinsic::lifetime_start: + handleLifetimeStart(I); + break; case Intrinsic::bswap: handleBswap(I); break; @@ -3127,6 +3206,17 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { handleVectorComparePackedIntrinsic(I); break; + case Intrinsic::x86_bmi_bextr_32: + case Intrinsic::x86_bmi_bextr_64: + case Intrinsic::x86_bmi_bzhi_32: + case Intrinsic::x86_bmi_bzhi_64: + case Intrinsic::x86_bmi_pdep_32: + case Intrinsic::x86_bmi_pdep_64: + case Intrinsic::x86_bmi_pext_32: + case Intrinsic::x86_bmi_pext_64: + handleBmiIntrinsic(I); + break; + case Intrinsic::is_constant: // The result of llvm.is.constant() is always defined. setShadow(&I, getCleanShadow(&I)); @@ -3143,21 +3233,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitCallSite(CallSite CS) { Instruction &I = *CS.getInstruction(); assert(!I.getMetadata("nosanitize")); - assert((CS.isCall() || CS.isInvoke()) && "Unknown type of CallSite"); + assert((CS.isCall() || CS.isInvoke() || CS.isCallBr()) && + "Unknown type of CallSite"); + if (CS.isCallBr() || (CS.isCall() && cast<CallInst>(&I)->isInlineAsm())) { + // For inline asm (either a call to asm function, or callbr instruction), + // do the usual thing: check argument shadow and mark all outputs as + // clean. Note that any side effects of the inline asm that are not + // immediately visible in its constraints are not handled. + if (ClHandleAsmConservative && MS.CompileKernel) + visitAsmInstruction(I); + else + visitInstruction(I); + return; + } if (CS.isCall()) { CallInst *Call = cast<CallInst>(&I); - - // For inline asm, do the usual thing: check argument shadow and mark all - // outputs as clean. Note that any side effects of the inline asm that are - // not immediately visible in its constraints are not handled. 
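[Editor's note] The BMI rule stated above, Sz = I(Sx, Y) | sext(Sy != 0), is easiest to see on a concrete scalar model. The sketch below is ordinary C++ (not IR construction), using PDEP as the example: x's shadow is pushed through the same bit permutation as x itself, and any uninitialized bit in the mask operand conservatively poisons the whole result.

#include <cstdint>

// Software parallel-deposit: deposit successive low bits of Src into the
// positions selected by Mask (what x86 PDEP computes).
static uint64_t pdep(uint64_t Src, uint64_t Mask) {
  uint64_t Result = 0;
  for (uint64_t MaskBit = 1; Mask != 0; MaskBit <<= 1) {
    if (Mask & MaskBit) {
      if (Src & 1)
        Result |= MaskBit;
      Src >>= 1;
      Mask &= ~MaskBit;
    }
  }
  return Result;
}

// Shadow for z = pdep(x, y): Sz = pdep(Sx, y) | (Sy != 0 ? ~0 : 0).
// A set shadow bit means "this bit of the value is uninitialized".
static uint64_t pdepShadow(uint64_t Sx, uint64_t Y, uint64_t Sy) {
  uint64_t FromX = pdep(Sx, Y);
  uint64_t FromMask = (Sy != 0) ? ~0ull : 0;
  return FromX | FromMask;
}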
- if (Call->isInlineAsm()) { - if (ClHandleAsmConservative && MS.CompileKernel) - visitAsmInstruction(I); - else - visitInstruction(I); - return; - } - assert(!isa<IntrinsicInst>(&I) && "intrinsics are handled elsewhere"); // We are going to insert code that relies on the fact that the callee @@ -3264,12 +3354,13 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { "Could not find insertion point for retval shadow load"); } IRBuilder<> IRBAfter(&*NextInsn); - Value *RetvalShadow = - IRBAfter.CreateAlignedLoad(getShadowPtrForRetval(&I, IRBAfter), - kShadowTLSAlignment, "_msret"); + Value *RetvalShadow = IRBAfter.CreateAlignedLoad( + getShadowTy(&I), getShadowPtrForRetval(&I, IRBAfter), + kShadowTLSAlignment, "_msret"); setShadow(&I, RetvalShadow); if (MS.TrackOrigins) - setOrigin(&I, IRBAfter.CreateLoad(getOriginPtrForRetval(IRBAfter))); + setOrigin(&I, IRBAfter.CreateLoad(MS.OriginTy, + getOriginPtrForRetval(IRBAfter))); } bool isAMustTailRetVal(Value *RetVal) { @@ -3330,7 +3421,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { StackDescription.str()); } - void instrumentAllocaUserspace(AllocaInst &I, IRBuilder<> &IRB, Value *Len) { + void poisonAllocaUserspace(AllocaInst &I, IRBuilder<> &IRB, Value *Len) { if (PoisonStack && ClPoisonStackWithCall) { IRB.CreateCall(MS.MsanPoisonStackFn, {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len}); @@ -3352,7 +3443,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } } - void instrumentAllocaKmsan(AllocaInst &I, IRBuilder<> &IRB, Value *Len) { + void poisonAllocaKmsan(AllocaInst &I, IRBuilder<> &IRB, Value *Len) { Value *Descr = getLocalVarDescription(I); if (PoisonStack) { IRB.CreateCall(MS.MsanPoisonAllocaFn, @@ -3364,10 +3455,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } } - void visitAllocaInst(AllocaInst &I) { - setShadow(&I, getCleanShadow(&I)); - setOrigin(&I, getCleanOrigin()); - IRBuilder<> IRB(I.getNextNode()); + void instrumentAlloca(AllocaInst &I, Instruction *InsPoint = nullptr) { + if (!InsPoint) + InsPoint = &I; + IRBuilder<> IRB(InsPoint->getNextNode()); const DataLayout &DL = F.getParent()->getDataLayout(); uint64_t TypeSize = DL.getTypeAllocSize(I.getAllocatedType()); Value *Len = ConstantInt::get(MS.IntptrTy, TypeSize); @@ -3375,9 +3466,17 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Len = IRB.CreateMul(Len, I.getArraySize()); if (MS.CompileKernel) - instrumentAllocaKmsan(I, IRB, Len); + poisonAllocaKmsan(I, IRB, Len); else - instrumentAllocaUserspace(I, IRB, Len); + poisonAllocaUserspace(I, IRB, Len); + } + + void visitAllocaInst(AllocaInst &I) { + setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); + // We'll get to this alloca later unless it's poisoned at the corresponding + // llvm.lifetime.start. + AllocaSet.insert(&I); } void visitSelectInst(SelectInst& I) { @@ -3409,7 +3508,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { D = CreateAppToShadowCast(IRB, D); // Result shadow if condition shadow is 1. - Sa1 = IRB.CreateOr(IRB.CreateXor(C, D), IRB.CreateOr(Sc, Sd)); + Sa1 = IRB.CreateOr({IRB.CreateXor(C, D), Sc, Sd}); } Value *Sa = IRB.CreateSelect(Sb, Sa1, Sa0, "_msprop_select"); setShadow(&I, Sa); @@ -3525,10 +3624,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } /// Get the number of output arguments returned by pointers. 
- int getNumOutputArgs(InlineAsm *IA, CallInst *CI) { + int getNumOutputArgs(InlineAsm *IA, CallBase *CB) { int NumRetOutputs = 0; int NumOutputs = 0; - Type *RetTy = dyn_cast<Value>(CI)->getType(); + Type *RetTy = dyn_cast<Value>(CB)->getType(); if (!RetTy->isVoidTy()) { // Register outputs are returned via the CallInst return value. StructType *ST = dyn_cast_or_null<StructType>(RetTy); @@ -3568,24 +3667,24 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // corresponding CallInst has nO+nI+1 operands (the last operand is the // function to be called). const DataLayout &DL = F.getParent()->getDataLayout(); - CallInst *CI = dyn_cast<CallInst>(&I); + CallBase *CB = dyn_cast<CallBase>(&I); IRBuilder<> IRB(&I); - InlineAsm *IA = cast<InlineAsm>(CI->getCalledValue()); - int OutputArgs = getNumOutputArgs(IA, CI); + InlineAsm *IA = cast<InlineAsm>(CB->getCalledValue()); + int OutputArgs = getNumOutputArgs(IA, CB); // The last operand of a CallInst is the function itself. - int NumOperands = CI->getNumOperands() - 1; + int NumOperands = CB->getNumOperands() - 1; // Check input arguments. Doing so before unpoisoning output arguments, so // that we won't overwrite uninit values before checking them. for (int i = OutputArgs; i < NumOperands; i++) { - Value *Operand = CI->getOperand(i); + Value *Operand = CB->getOperand(i); instrumentAsmArgument(Operand, I, IRB, DL, /*isOutput*/ false); } // Unpoison output arguments. This must happen before the actual InlineAsm // call, so that the shadow for memory published in the asm() statement // remains valid. for (int i = 0; i < OutputArgs; i++) { - Value *Operand = CI->getOperand(i); + Value *Operand = CB->getOperand(i); instrumentAsmArgument(Operand, I, IRB, DL, /*isOutput*/ true); } @@ -3817,7 +3916,8 @@ struct VarArgAMD64Helper : public VarArgHelper { // If there is a va_start in this function, make a backup copy of // va_arg_tls somewhere in the function entry block. 
IRBuilder<> IRB(MSV.ActualFnStart->getFirstNonPHI()); - VAArgOverflowSize = IRB.CreateLoad(MS.VAArgOverflowSizeTLS); + VAArgOverflowSize = + IRB.CreateLoad(IRB.getInt64Ty(), MS.VAArgOverflowSizeTLS); Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, AMD64FpEndOffset), VAArgOverflowSize); @@ -3836,11 +3936,13 @@ struct VarArgAMD64Helper : public VarArgHelper { IRBuilder<> IRB(OrigInst->getNextNode()); Value *VAListTag = OrigInst->getArgOperand(0); + Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C); Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, 16)), - PointerType::get(Type::getInt64PtrTy(*MS.C), 0)); - Value *RegSaveAreaPtr = IRB.CreateLoad(RegSaveAreaPtrPtr); + PointerType::get(RegSaveAreaPtrTy, 0)); + Value *RegSaveAreaPtr = + IRB.CreateLoad(RegSaveAreaPtrTy, RegSaveAreaPtrPtr); Value *RegSaveAreaShadowPtr, *RegSaveAreaOriginPtr; unsigned Alignment = 16; std::tie(RegSaveAreaShadowPtr, RegSaveAreaOriginPtr) = @@ -3851,11 +3953,13 @@ struct VarArgAMD64Helper : public VarArgHelper { if (MS.TrackOrigins) IRB.CreateMemCpy(RegSaveAreaOriginPtr, Alignment, VAArgTLSOriginCopy, Alignment, AMD64FpEndOffset); + Type *OverflowArgAreaPtrTy = Type::getInt64PtrTy(*MS.C); Value *OverflowArgAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, 8)), - PointerType::get(Type::getInt64PtrTy(*MS.C), 0)); - Value *OverflowArgAreaPtr = IRB.CreateLoad(OverflowArgAreaPtrPtr); + PointerType::get(OverflowArgAreaPtrTy, 0)); + Value *OverflowArgAreaPtr = + IRB.CreateLoad(OverflowArgAreaPtrTy, OverflowArgAreaPtrPtr); Value *OverflowArgAreaShadowPtr, *OverflowArgAreaOriginPtr; std::tie(OverflowArgAreaShadowPtr, OverflowArgAreaOriginPtr) = MSV.getShadowOriginPtr(OverflowArgAreaPtr, IRB, IRB.getInt8Ty(), @@ -3957,7 +4061,7 @@ struct VarArgMIPS64Helper : public VarArgHelper { assert(!VAArgSize && !VAArgTLSCopy && "finalizeInstrumentation called twice"); IRBuilder<> IRB(MSV.ActualFnStart->getFirstNonPHI()); - VAArgSize = IRB.CreateLoad(MS.VAArgOverflowSizeTLS); + VAArgSize = IRB.CreateLoad(IRB.getInt64Ty(), MS.VAArgOverflowSizeTLS); Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, 0), VAArgSize); @@ -3974,10 +4078,12 @@ struct VarArgMIPS64Helper : public VarArgHelper { CallInst *OrigInst = VAStartInstrumentationList[i]; IRBuilder<> IRB(OrigInst->getNextNode()); Value *VAListTag = OrigInst->getArgOperand(0); + Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C); Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), - PointerType::get(Type::getInt64PtrTy(*MS.C), 0)); - Value *RegSaveAreaPtr = IRB.CreateLoad(RegSaveAreaPtrPtr); + PointerType::get(RegSaveAreaPtrTy, 0)); + Value *RegSaveAreaPtr = + IRB.CreateLoad(RegSaveAreaPtrTy, RegSaveAreaPtrPtr); Value *RegSaveAreaShadowPtr, *RegSaveAreaOriginPtr; unsigned Alignment = 8; std::tie(RegSaveAreaShadowPtr, RegSaveAreaOriginPtr) = @@ -4127,7 +4233,7 @@ struct VarArgAArch64Helper : public VarArgHelper { IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, offset)), Type::getInt64PtrTy(*MS.C)); - return IRB.CreateLoad(SaveAreaPtrPtr); + return IRB.CreateLoad(Type::getInt64Ty(*MS.C), SaveAreaPtrPtr); } // Retrieve a va_list field of 'int' size. 
@@ -4137,7 +4243,7 @@ struct VarArgAArch64Helper : public VarArgHelper { IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, offset)), Type::getInt32PtrTy(*MS.C)); - Value *SaveArea32 = IRB.CreateLoad(SaveAreaPtr); + Value *SaveArea32 = IRB.CreateLoad(IRB.getInt32Ty(), SaveAreaPtr); return IRB.CreateSExt(SaveArea32, MS.IntptrTy); } @@ -4148,7 +4254,8 @@ struct VarArgAArch64Helper : public VarArgHelper { // If there is a va_start in this function, make a backup copy of // va_arg_tls somewhere in the function entry block. IRBuilder<> IRB(MSV.ActualFnStart->getFirstNonPHI()); - VAArgOverflowSize = IRB.CreateLoad(MS.VAArgOverflowSizeTLS); + VAArgOverflowSize = + IRB.CreateLoad(IRB.getInt64Ty(), MS.VAArgOverflowSizeTLS); Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, AArch64VAEndOffset), VAArgOverflowSize); @@ -4391,7 +4498,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper { assert(!VAArgSize && !VAArgTLSCopy && "finalizeInstrumentation called twice"); IRBuilder<> IRB(MSV.ActualFnStart->getFirstNonPHI()); - VAArgSize = IRB.CreateLoad(MS.VAArgOverflowSizeTLS); + VAArgSize = IRB.CreateLoad(IRB.getInt64Ty(), MS.VAArgOverflowSizeTLS); Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, 0), VAArgSize); @@ -4408,10 +4515,12 @@ struct VarArgPowerPC64Helper : public VarArgHelper { CallInst *OrigInst = VAStartInstrumentationList[i]; IRBuilder<> IRB(OrigInst->getNextNode()); Value *VAListTag = OrigInst->getArgOperand(0); + Type *RegSaveAreaPtrTy = Type::getInt64PtrTy(*MS.C); Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), - PointerType::get(Type::getInt64PtrTy(*MS.C), 0)); - Value *RegSaveAreaPtr = IRB.CreateLoad(RegSaveAreaPtrPtr); + PointerType::get(RegSaveAreaPtrTy, 0)); + Value *RegSaveAreaPtr = + IRB.CreateLoad(RegSaveAreaPtrTy, RegSaveAreaPtrPtr); Value *RegSaveAreaShadowPtr, *RegSaveAreaOriginPtr; unsigned Alignment = 8; std::tie(RegSaveAreaShadowPtr, RegSaveAreaOriginPtr) = @@ -4458,6 +4567,8 @@ static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, } bool MemorySanitizer::sanitizeFunction(Function &F, TargetLibraryInfo &TLI) { + if (!CompileKernel && (&F == MsanCtorFunction)) + return false; MemorySanitizerVisitor Visitor(F, *this, TLI); // Clear out readonly/readnone attributes. diff --git a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index f043325f5bba..6fec3c9c79ee 100644 --- a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -1,9 +1,8 @@ //===- PGOInstrumentation.cpp - MST-based PGO Instrumentation -------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -48,7 +47,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" #include "CFGMST.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -66,6 +64,7 @@ #include "llvm/Analysis/IndirectCallVisitor.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -107,6 +106,7 @@ #include "llvm/Support/JamCRC.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <algorithm> #include <cassert> @@ -133,6 +133,19 @@ STATISTIC(NumOfPGOFunc, "Number of functions having valid profile counts."); STATISTIC(NumOfPGOMismatch, "Number of functions having mismatch profile."); STATISTIC(NumOfPGOMissing, "Number of functions without profile."); STATISTIC(NumOfPGOICall, "Number of indirect call value instrumentations."); +STATISTIC(NumOfCSPGOInstrument, "Number of edges instrumented in CSPGO."); +STATISTIC(NumOfCSPGOSelectInsts, + "Number of select instruction instrumented in CSPGO."); +STATISTIC(NumOfCSPGOMemIntrinsics, + "Number of mem intrinsics instrumented in CSPGO."); +STATISTIC(NumOfCSPGOEdge, "Number of edges in CSPGO."); +STATISTIC(NumOfCSPGOBB, "Number of basic-blocks in CSPGO."); +STATISTIC(NumOfCSPGOSplit, "Number of critical edge splits in CSPGO."); +STATISTIC(NumOfCSPGOFunc, + "Number of functions having valid profile counts in CSPGO."); +STATISTIC(NumOfCSPGOMismatch, + "Number of functions having mismatch profile in CSPGO."); +STATISTIC(NumOfCSPGOMissing, "Number of functions without profile in CSPGO."); // Command line option to specify the file to read profile from. This is // mainly used for testing. @@ -384,7 +397,8 @@ class PGOInstrumentationGenLegacyPass : public ModulePass { public: static char ID; - PGOInstrumentationGenLegacyPass() : ModulePass(ID) { + PGOInstrumentationGenLegacyPass(bool IsCS = false) + : ModulePass(ID), IsCS(IsCS) { initializePGOInstrumentationGenLegacyPassPass( *PassRegistry::getPassRegistry()); } @@ -392,6 +406,8 @@ public: StringRef getPassName() const override { return "PGOInstrumentationGenPass"; } private: + // Is this is context-sensitive instrumentation. + bool IsCS; bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -404,8 +420,8 @@ public: static char ID; // Provide the profile filename as the parameter. - PGOInstrumentationUseLegacyPass(std::string Filename = "") - : ModulePass(ID), ProfileFileName(std::move(Filename)) { + PGOInstrumentationUseLegacyPass(std::string Filename = "", bool IsCS = false) + : ModulePass(ID), ProfileFileName(std::move(Filename)), IsCS(IsCS) { if (!PGOTestProfileFile.empty()) ProfileFileName = PGOTestProfileFile; initializePGOInstrumentationUseLegacyPassPass( @@ -416,14 +432,38 @@ public: private: std::string ProfileFileName; + // Is this is context-sensitive instrumentation use. 
+ bool IsCS; bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<ProfileSummaryInfoWrapperPass>(); AU.addRequired<BlockFrequencyInfoWrapperPass>(); } }; +class PGOInstrumentationGenCreateVarLegacyPass : public ModulePass { +public: + static char ID; + StringRef getPassName() const override { + return "PGOInstrumentationGenCreateVarPass"; + } + PGOInstrumentationGenCreateVarLegacyPass(std::string CSInstrName = "") + : ModulePass(ID), InstrProfileOutput(CSInstrName) { + initializePGOInstrumentationGenCreateVarLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + +private: + bool runOnModule(Module &M) override { + createProfileFileNameVar(M, InstrProfileOutput); + createIRLevelProfileFlagVar(M, true); + return false; + } + std::string InstrProfileOutput; +}; + } // end anonymous namespace char PGOInstrumentationGenLegacyPass::ID = 0; @@ -435,8 +475,8 @@ INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) INITIALIZE_PASS_END(PGOInstrumentationGenLegacyPass, "pgo-instr-gen", "PGO instrumentation.", false, false) -ModulePass *llvm::createPGOInstrumentationGenLegacyPass() { - return new PGOInstrumentationGenLegacyPass(); +ModulePass *llvm::createPGOInstrumentationGenLegacyPass(bool IsCS) { + return new PGOInstrumentationGenLegacyPass(IsCS); } char PGOInstrumentationUseLegacyPass::ID = 0; @@ -445,11 +485,25 @@ INITIALIZE_PASS_BEGIN(PGOInstrumentationUseLegacyPass, "pgo-instr-use", "Read PGO instrumentation profile.", false, false) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) INITIALIZE_PASS_END(PGOInstrumentationUseLegacyPass, "pgo-instr-use", "Read PGO instrumentation profile.", false, false) -ModulePass *llvm::createPGOInstrumentationUseLegacyPass(StringRef Filename) { - return new PGOInstrumentationUseLegacyPass(Filename.str()); +ModulePass *llvm::createPGOInstrumentationUseLegacyPass(StringRef Filename, + bool IsCS) { + return new PGOInstrumentationUseLegacyPass(Filename.str(), IsCS); +} + +char PGOInstrumentationGenCreateVarLegacyPass::ID = 0; + +INITIALIZE_PASS(PGOInstrumentationGenCreateVarLegacyPass, + "pgo-instr-gen-create-var", + "Create PGO instrumentation version variable for CSPGO.", false, + false) + +ModulePass * +llvm::createPGOInstrumentationGenCreateVarLegacyPass(StringRef CSInstrName) { + return new PGOInstrumentationGenCreateVarLegacyPass(CSInstrName); } namespace { @@ -490,6 +544,12 @@ struct BBInfo { const std::string infoString() const { return (Twine("Index=") + Twine(Index)).str(); } + + // Empty function -- only applicable to UseBBInfo. + void addOutEdge(PGOEdge *E LLVM_ATTRIBUTE_UNUSED) {} + + // Empty function -- only applicable to UseBBInfo. + void addInEdge(PGOEdge *E LLVM_ATTRIBUTE_UNUSED) {} }; // This class implements the CFG edges. Note the CFG can be a multi-graph. @@ -497,6 +557,9 @@ template <class Edge, class BBInfo> class FuncPGOInstrumentation { private: Function &F; + // Is this is context-sensitive instrumentation. + bool IsCS; + // A map that stores the Comdat group in function F. std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers; @@ -516,6 +579,10 @@ public: // The Minimum Spanning Tree of function CFG. CFGMST<Edge, BBInfo> MST; + // Collect all the BBs that will be instrumented, and store them in + // InstrumentBBs. 
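[Editor's note] A standalone sketch of the function-hash layout in the computeCFGHash hunk above (plain C++, not the pass code): the select / indirect-call / edge counts occupy the high bytes, the CFG CRC the low 32 bits, bits 60-63 are masked off as reserved, and the context-sensitive variant then sets a flag in that reserved range (NamedInstrProfRecord::setCSFlagInHash in the real code; the exact bit used below is illustrative).

#include <cstdint>

static uint64_t makeFunctionHash(uint64_t NumSelects, uint64_t NumIndirectCalls,
                                 uint64_t NumEdges, uint32_t CfgCrc, bool IsCS) {
  uint64_t Hash = (NumSelects << 56) | (NumIndirectCalls << 48) |
                  (NumEdges << 32) | CfgCrc;
  Hash &= 0x0FFFFFFFFFFFFFFFULL;  // keep bits 60-63 free for flags
  if (IsCS)
    Hash |= 1ULL << 60;           // illustrative flag bit; see setCSFlagInHash
  return Hash;
}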
+ void getInstrumentBBs(std::vector<BasicBlock *> &InstrumentBBs); + // Give an edge, find the BB that will be instrumented. // Return nullptr if there is no BB to be instrumented. BasicBlock *getInstrBB(Edge *E); @@ -536,15 +603,23 @@ public: Function &Func, std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr, - BlockFrequencyInfo *BFI = nullptr) - : F(Func), ComdatMembers(ComdatMembers), ValueSites(IPVK_Last + 1), - SIVisitor(Func), MIVisitor(Func), MST(F, BPI, BFI) { + BlockFrequencyInfo *BFI = nullptr, bool IsCS = false) + : F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers), + ValueSites(IPVK_Last + 1), SIVisitor(Func), MIVisitor(Func), + MST(F, BPI, BFI) { // This should be done before CFG hash computation. SIVisitor.countSelects(Func); MIVisitor.countMemIntrinsics(Func); - NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); - NumOfPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics(); - ValueSites[IPVK_IndirectCallTarget] = findIndirectCalls(Func); + if (!IsCS) { + NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); + NumOfPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics(); + NumOfPGOBB += MST.BBInfos.size(); + ValueSites[IPVK_IndirectCallTarget] = findIndirectCalls(Func); + } else { + NumOfCSPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); + NumOfCSPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics(); + NumOfCSPGOBB += MST.BBInfos.size(); + } ValueSites[IPVK_MemOPSize] = MIVisitor.findMemIntrinsics(Func); FuncName = getPGOFuncName(F); @@ -553,28 +628,17 @@ public: renameComdatFunction(); LLVM_DEBUG(dumpInfo("after CFGMST")); - NumOfPGOBB += MST.BBInfos.size(); for (auto &E : MST.AllEdges) { if (E->Removed) continue; - NumOfPGOEdge++; + IsCS ? NumOfCSPGOEdge++ : NumOfPGOEdge++; if (!E->InMST) - NumOfPGOInstrument++; + IsCS ? NumOfCSPGOInstrument++ : NumOfPGOInstrument++; } if (CreateGlobalVar) FuncNameVar = createPGOFuncNameVar(F, FuncName); } - - // Return the number of profile counters needed for the function. - unsigned getNumCounters() { - unsigned NumCounters = 0; - for (auto &E : this->MST.AllEdges) { - if (!E->InMST && !E->Removed) - NumCounters++; - } - return NumCounters + SIVisitor.getNumOfSelectInsts(); - } }; } // end anonymous namespace @@ -598,9 +662,17 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { } } JC.update(Indexes); + + // Hash format for context sensitive profile. Reserve 4 bits for other + // information. FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 | (uint64_t)ValueSites[IPVK_IndirectCallTarget].size() << 48 | + //(uint64_t)ValueSites[IPVK_MemOPSize].size() << 40 | (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC(); + // Reserve bit 60-63 for other information purpose. + FunctionHash &= 0x0FFFFFFFFFFFFFFF; + if (IsCS) + NamedInstrProfRecord::setCSFlagInHash(FunctionHash); LLVM_DEBUG(dbgs() << "Function Hash Computation for " << F.getName() << ":\n" << " CRC = " << JC.getCRC() << ", Selects = " << SIVisitor.getNumOfSelectInsts() @@ -681,6 +753,36 @@ void FuncPGOInstrumentation<Edge, BBInfo>::renameComdatFunction() { } } +// Collect all the BBs that will be instruments and return them in +// InstrumentBBs and setup InEdges/OutEdge for UseBBInfo. +template <class Edge, class BBInfo> +void FuncPGOInstrumentation<Edge, BBInfo>::getInstrumentBBs( + std::vector<BasicBlock *> &InstrumentBBs) { + // Use a worklist as we will update the vector during the iteration. 
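The FunctionHash packing in computeCFGHash() above can be restated as plain arithmetic: the select count sits just below the reserved top nibble, the indirect-call-site and edge counts occupy the middle fields, and the JamCRC of the block indexes fills the low 32 bits; a context-sensitive hash additionally sets a flag bit. A standalone restatement (the exact bit chosen by NamedInstrProfRecord::setCSFlagInHash is an assumption here):

#include <cstdint>

// Illustrative only; field layout follows the code above.
uint64_t packFunctionHash(uint64_t NumSelects, uint64_t NumIndirectSites,
                          uint64_t NumEdges, uint32_t CRC, bool IsCS) {
  uint64_t Hash = (NumSelects << 56) | (NumIndirectSites << 48) |
                  (NumEdges << 32) | CRC;
  Hash &= 0x0FFFFFFFFFFFFFFFull; // bits 60-63 reserved for other information
  if (IsCS)
    Hash |= 1ull << 60;          // assumption: the CS flag occupies bit 60
  return Hash;
}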
+ std::vector<Edge *> EdgeList; + EdgeList.reserve(MST.AllEdges.size()); + for (auto &E : MST.AllEdges) + EdgeList.push_back(E.get()); + + for (auto &E : EdgeList) { + BasicBlock *InstrBB = getInstrBB(E); + if (InstrBB) + InstrumentBBs.push_back(InstrBB); + } + + // Set up InEdges/OutEdges for all BBs. + for (auto &E : MST.AllEdges) { + if (E->Removed) + continue; + const BasicBlock *SrcBB = E->SrcBB; + const BasicBlock *DestBB = E->DestBB; + BBInfo &SrcInfo = getBBInfo(SrcBB); + BBInfo &DestInfo = getBBInfo(DestBB); + SrcInfo.addOutEdge(E.get()); + DestInfo.addInEdge(E.get()); + } +} + // Given a CFG E to be instrumented, find which BB to place the instrumented // code. The function will split the critical edge if necessary. template <class Edge, class BBInfo> @@ -696,46 +798,64 @@ BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) { if (DestBB == nullptr) return SrcBB; + auto canInstrument = [](BasicBlock *BB) -> BasicBlock * { + // There are basic blocks (such as catchswitch) cannot be instrumented. + // If the returned first insertion point is the end of BB, skip this BB. + if (BB->getFirstInsertionPt() == BB->end()) + return nullptr; + return BB; + }; + // Instrument the SrcBB if it has a single successor, // otherwise, the DestBB if this is not a critical edge. Instruction *TI = SrcBB->getTerminator(); if (TI->getNumSuccessors() <= 1) - return SrcBB; + return canInstrument(SrcBB); if (!E->IsCritical) - return DestBB; + return canInstrument(DestBB); + unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB); + BasicBlock *InstrBB = SplitCriticalEdge(TI, SuccNum); + if (!InstrBB) { + LLVM_DEBUG( + dbgs() << "Fail to split critical edge: not instrument this edge.\n"); + return nullptr; + } // For a critical edge, we have to split. Instrument the newly // created BB. - NumOfPGOSplit++; + IsCS ? NumOfCSPGOSplit++ : NumOfPGOSplit++; LLVM_DEBUG(dbgs() << "Split critical edge: " << getBBInfo(SrcBB).Index << " --> " << getBBInfo(DestBB).Index << "\n"); - unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB); - BasicBlock *InstrBB = SplitCriticalEdge(TI, SuccNum); - assert(InstrBB && "Critical edge is not split"); - + // Need to add two new edges. First one: Add new edge of SrcBB->InstrBB. + MST.addEdge(SrcBB, InstrBB, 0); + // Second one: Add new edge of InstrBB->DestBB. + Edge &NewEdge1 = MST.addEdge(InstrBB, DestBB, 0); + NewEdge1.InMST = true; E->Removed = true; - return InstrBB; + + return canInstrument(InstrBB); } // Visit all edge and instrument the edges not in MST, and do value profiling. // Critical edges will be split. static void instrumentOneFunc( Function &F, Module *M, BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFI, - std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers) { + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, + bool IsCS) { // Split indirectbr critical edges here before computing the MST rather than // later in getInstrBB() to avoid invalidating it. 
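The canInstrument guard used by getInstrBB() above amounts to a one-line check; spelled out separately (a sketch, not additional functionality), it is simply:

#include "llvm/IR/BasicBlock.h"

// Blocks such as those terminated by a catchswitch have no valid insertion
// point; getInstrBB() skips them rather than asserting.
static bool canPlaceCounter(llvm::BasicBlock &BB) {
  return BB.getFirstInsertionPt() != BB.end();
}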
SplitIndirectBrCriticalEdges(F, BPI, BFI); + FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, ComdatMembers, true, BPI, - BFI); - unsigned NumCounters = FuncInfo.getNumCounters(); + BFI, IsCS); + std::vector<BasicBlock *> InstrumentBBs; + FuncInfo.getInstrumentBBs(InstrumentBBs); + unsigned NumCounters = + InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts(); uint32_t I = 0; Type *I8PtrTy = Type::getInt8PtrTy(M->getContext()); - for (auto &E : FuncInfo.MST.AllEdges) { - BasicBlock *InstrBB = FuncInfo.getInstrBB(E.get()); - if (!InstrBB) - continue; - + for (auto *InstrBB : InstrumentBBs) { IRBuilder<> Builder(InstrBB, InstrBB->getFirstInsertionPt()); assert(Builder.GetInsertPoint() != InstrBB->end() && "Cannot get the Instrumentation point"); @@ -831,6 +951,18 @@ struct UseBBInfo : public BBInfo { return BBInfo::infoString(); return (Twine(BBInfo::infoString()) + " Count=" + Twine(CountValue)).str(); } + + // Add an OutEdge and update the edge count. + void addOutEdge(PGOUseEdge *E) { + OutEdges.push_back(E); + UnknownCountOutEdge++; + } + + // Add an InEdge and update the edge count. + void addInEdge(PGOUseEdge *E) { + InEdges.push_back(E); + UnknownCountInEdge++; + } }; } // end anonymous namespace @@ -853,10 +985,10 @@ public: PGOUseFunc(Function &Func, Module *Modu, std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, BranchProbabilityInfo *BPI = nullptr, - BlockFrequencyInfo *BFIin = nullptr) + BlockFrequencyInfo *BFIin = nullptr, bool IsCS = false) : F(Func), M(Modu), BFI(BFIin), - FuncInfo(Func, ComdatMembers, false, BPI, BFIin), - FreqAttr(FFA_Normal) {} + FuncInfo(Func, ComdatMembers, false, BPI, BFIin, IsCS), + FreqAttr(FFA_Normal), IsCS(IsCS) {} // Read counts for the instrumented BB from profile. bool readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros); @@ -929,8 +1061,11 @@ private: // Function hotness info derived from profile. FuncFreqAttr FreqAttr; - // Find the Instrumented BB and set the value. - void setInstrumentedCounts(const std::vector<uint64_t> &CountFromProfile); + // Is to use the context sensitive profile. + bool IsCS; + + // Find the Instrumented BB and set the value. Return false on error. + bool setInstrumentedCounts(const std::vector<uint64_t> &CountFromProfile); // Set the edge counter value for the unknown edge -- there should be only // one unknown edge. @@ -959,41 +1094,64 @@ private: } // end anonymous namespace // Visit all the edges and assign the count value for the instrumented -// edges and the BB. -void PGOUseFunc::setInstrumentedCounts( +// edges and the BB. Return false on error. +bool PGOUseFunc::setInstrumentedCounts( const std::vector<uint64_t> &CountFromProfile) { - assert(FuncInfo.getNumCounters() == CountFromProfile.size()); - // Use a worklist as we will update the vector during the iteration. - std::vector<PGOUseEdge *> WorkList; - for (auto &E : FuncInfo.MST.AllEdges) - WorkList.push_back(E.get()); + std::vector<BasicBlock *> InstrumentBBs; + FuncInfo.getInstrumentBBs(InstrumentBBs); + unsigned NumCounters = + InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts(); + // The number of counters here should match the number of counters + // in profile. Return if they mismatch. + if (NumCounters != CountFromProfile.size()) { + return false; + } + // Set the profile count to the Instrumented BBs. 
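For context, each block collected in InstrumentBBs receives an llvm.instrprof.increment call indexed by I; the emission continues below the hunk shown and follows roughly this shape (a sketch reconstructed from the surrounding code, with select lowering and value profiling omitted):

#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
using namespace llvm;

static void emitCounterIncrement(Module *M, BasicBlock *InstrBB,
                                 GlobalVariable *FuncNameVar, uint64_t FuncHash,
                                 uint32_t NumCounters, uint32_t Index) {
  IRBuilder<> Builder(InstrBB, InstrBB->getFirstInsertionPt());
  Type *I8PtrTy = Type::getInt8PtrTy(M->getContext());
  // One counter slot per instrumented block (plus the select slots).
  Builder.CreateCall(
      Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment),
      {ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
       Builder.getInt64(FuncHash), Builder.getInt32(NumCounters),
       Builder.getInt32(Index)});
}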
uint32_t I = 0; - for (auto &E : WorkList) { - BasicBlock *InstrBB = FuncInfo.getInstrBB(E); - if (!InstrBB) - continue; + for (BasicBlock *InstrBB : InstrumentBBs) { uint64_t CountValue = CountFromProfile[I++]; - if (!E->Removed) { - getBBInfo(InstrBB).setBBInfoCount(CountValue); - E->setEdgeCount(CountValue); - continue; - } - - // Need to add two new edges. - BasicBlock *SrcBB = const_cast<BasicBlock *>(E->SrcBB); - BasicBlock *DestBB = const_cast<BasicBlock *>(E->DestBB); - // Add new edge of SrcBB->InstrBB. - PGOUseEdge &NewEdge = FuncInfo.MST.addEdge(SrcBB, InstrBB, 0); - NewEdge.setEdgeCount(CountValue); - // Add new edge of InstrBB->DestBB. - PGOUseEdge &NewEdge1 = FuncInfo.MST.addEdge(InstrBB, DestBB, 0); - NewEdge1.setEdgeCount(CountValue); - NewEdge1.InMST = true; - getBBInfo(InstrBB).setBBInfoCount(CountValue); + UseBBInfo &Info = getBBInfo(InstrBB); + Info.setBBInfoCount(CountValue); } ProfileCountSize = CountFromProfile.size(); CountPosition = I; + + // Set the edge count and update the count of unknown edges for BBs. + auto setEdgeCount = [this](PGOUseEdge *E, uint64_t Value) -> void { + E->setEdgeCount(Value); + this->getBBInfo(E->SrcBB).UnknownCountOutEdge--; + this->getBBInfo(E->DestBB).UnknownCountInEdge--; + }; + + // Set the profile count for the instrumented edges. There are also edges not + // in the MST that were not instrumented; we need to set their edge count + // values so that we can populate the profile counts later. + for (auto &E : FuncInfo.MST.AllEdges) { + if (E->Removed || E->InMST) + continue; + const BasicBlock *SrcBB = E->SrcBB; + UseBBInfo &SrcInfo = getBBInfo(SrcBB); + + // If only one out-edge, the edge profile count should be the same as BB + // profile count. + if (SrcInfo.CountValid && SrcInfo.OutEdges.size() == 1) + setEdgeCount(E.get(), SrcInfo.CountValue); + else { + const BasicBlock *DestBB = E->DestBB; + UseBBInfo &DestInfo = getBBInfo(DestBB); + // If only one in-edge, the edge profile count should be the same as BB + // profile count. + if (DestInfo.CountValid && DestInfo.InEdges.size() == 1) + setEdgeCount(E.get(), DestInfo.CountValue); + } + if (E->CountValid) + continue; + // E's count should have been set from the profile. If not, this means E + // skipped the instrumentation. We set the count to 0. + setEdgeCount(E.get(), 0); + } + return true; } // Set the count value for the unknown edge. There should be one and only one @@ -1022,23 +1180,31 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros) handleAllErrors(std::move(E), [&](const InstrProfError &IPE) { auto Err = IPE.get(); bool SkipWarning = false; + LLVM_DEBUG(dbgs() << "Error in reading profile for Func " + << FuncInfo.FuncName << ": "); if (Err == instrprof_error::unknown_function) { - NumOfPGOMissing++; + IsCS ? NumOfCSPGOMissing++ : NumOfPGOMissing++; SkipWarning = !PGOWarnMissing; + LLVM_DEBUG(dbgs() << "unknown function"); } else if (Err == instrprof_error::hash_mismatch || Err == instrprof_error::malformed) { - NumOfPGOMismatch++; + IsCS ?
NumOfCSPGOMismatch++ : NumOfPGOMismatch++; SkipWarning = NoPGOWarnMismatch || (NoPGOWarnMismatchComdat && (F.hasComdat() || F.getLinkage() == GlobalValue::AvailableExternallyLinkage)); + LLVM_DEBUG(dbgs() << "hash mismatch (skip=" << SkipWarning << ")"); } + LLVM_DEBUG(dbgs() << " IsCS=" << IsCS << "\n"); if (SkipWarning) return; - std::string Msg = IPE.message() + std::string(" ") + F.getName().str(); + std::string Msg = IPE.message() + std::string(" ") + F.getName().str() + + std::string(" Hash = ") + + std::to_string(FuncInfo.FunctionHash); + Ctx.diagnose( DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning)); }); @@ -1047,7 +1213,7 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros) ProfileRecord = std::move(Result.get()); std::vector<uint64_t> &CountFromProfile = ProfileRecord.Counts; - NumOfPGOFunc++; + IsCS ? NumOfCSPGOFunc++ : NumOfPGOFunc++; LLVM_DEBUG(dbgs() << CountFromProfile.size() << " counts\n"); uint64_t ValueSum = 0; for (unsigned I = 0, S = CountFromProfile.size(); I < S; I++) { @@ -1061,34 +1227,23 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros) getBBInfo(nullptr).UnknownCountOutEdge = 2; getBBInfo(nullptr).UnknownCountInEdge = 2; - setInstrumentedCounts(CountFromProfile); - ProgramMaxCount = PGOReader->getMaximumFunctionCount(); + if (!setInstrumentedCounts(CountFromProfile)) { + LLVM_DEBUG( + dbgs() << "Inconsistent number of counts, skipping this function"); + Ctx.diagnose(DiagnosticInfoPGOProfile( + M->getName().data(), + Twine("Inconsistent number of counts in ") + F.getName().str() + + Twine(": the profile may be stale or there is a function name collision."), + DS_Warning)); + return false; + } + ProgramMaxCount = PGOReader->getMaximumFunctionCount(IsCS); return true; } // Populate the counters from instrumented BBs to all BBs. // In the end of this operation, all BBs should have a valid count value. void PGOUseFunc::populateCounters() { - // First set up Count variable for all BBs. - for (auto &E : FuncInfo.MST.AllEdges) { - if (E->Removed) - continue; - - const BasicBlock *SrcBB = E->SrcBB; - const BasicBlock *DestBB = E->DestBB; - UseBBInfo &SrcInfo = getBBInfo(SrcBB); - UseBBInfo &DestInfo = getBBInfo(DestBB); - SrcInfo.OutEdges.push_back(E.get()); - DestInfo.InEdges.push_back(E.get()); - SrcInfo.UnknownCountOutEdge++; - DestInfo.UnknownCountInEdge++; - - if (!E->CountValid) - continue; - DestInfo.UnknownCountInEdge--; - SrcInfo.UnknownCountOutEdge--; - } - bool Changes = true; unsigned NumPasses = 0; while (Changes) { @@ -1167,7 +1322,8 @@ void PGOUseFunc::populateCounters() { // Assign the scaled count values to the BB with multiple out edges. void PGOUseFunc::setBranchWeights() { // Generate MD_prof metadata for every branch instruction. 
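setBranchWeights(), whose body continues below, ultimately attaches MD_prof branch_weights metadata scaled from the populated block counts. A self-contained illustration of what that annotation amounts to for a two-way branch (scaling and the switch/indirectbr cases omitted):

#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
using namespace llvm;

static void annotateTwoWayBranch(BranchInst *BI, uint32_t TrueWeight,
                                 uint32_t FalseWeight) {
  MDBuilder MDB(BI->getContext());
  // Later consumers (e.g. BranchProbabilityInfo) read this metadata back.
  BI->setMetadata(LLVMContext::MD_prof,
                  MDB.createBranchWeights(TrueWeight, FalseWeight));
}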
- LLVM_DEBUG(dbgs() << "\nSetting branch weights.\n"); + LLVM_DEBUG(dbgs() << "\nSetting branch weights for func " << F.getName() + << " IsCS=" << IsCS << "\n"); for (auto &BB : F) { Instruction *TI = BB.getTerminator(); if (TI->getNumSuccessors() < 2) @@ -1175,6 +1331,7 @@ void PGOUseFunc::setBranchWeights() { if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI))) continue; + if (getBBInfo(&BB).CountValue == 0) continue; @@ -1282,7 +1439,7 @@ void MemIntrinsicVisitor::instrumentOneMemIntrinsic(MemIntrinsic &MI) { Type *Int64Ty = Builder.getInt64Ty(); Type *I8PtrTy = Builder.getInt8PtrTy(); Value *Length = MI.getLength(); - assert(!dyn_cast<ConstantInt>(Length)); + assert(!isa<ConstantInt>(Length)); Builder.CreateCall( Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile), {ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), @@ -1325,8 +1482,14 @@ void PGOUseFunc::annotateValueSites() { annotateValueSites(Kind); } +static const char *ValueProfKindDescr[] = { +#define VALUE_PROF_KIND(Enumerator, Value, Descr) Descr, +#include "llvm/ProfileData/InstrProfData.inc" +}; + // Annotate the instructions for a specific value kind. void PGOUseFunc::annotateValueSites(uint32_t Kind) { + assert(Kind <= IPVK_Last); unsigned ValueSiteIndex = 0; auto &ValueSites = FuncInfo.ValueSites[Kind]; unsigned NumValueSites = ProfileRecord.getNumValueSites(Kind); @@ -1334,8 +1497,10 @@ void PGOUseFunc::annotateValueSites(uint32_t Kind) { auto &Ctx = M->getContext(); Ctx.diagnose(DiagnosticInfoPGOProfile( M->getName().data(), - Twine("Inconsistent number of value sites for kind = ") + Twine(Kind) + - " in " + F.getName().str(), + Twine("Inconsistent number of value sites for ") + + Twine(ValueProfKindDescr[Kind]) + + Twine(" profiling in \"") + F.getName().str() + + Twine("\", possibly due to the use of a stale profile."), DS_Warning)); return; } @@ -1352,24 +1517,6 @@ void PGOUseFunc::annotateValueSites(uint32_t Kind) { } } -// Create a COMDAT variable INSTR_PROF_RAW_VERSION_VAR to make the runtime -// aware this is an ir_level profile so it can set the version flag. -static void createIRLevelProfileFlagVariable(Module &M) { - Type *IntTy64 = Type::getInt64Ty(M.getContext()); - uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF); - auto IRLevelVersionVariable = new GlobalVariable( - M, IntTy64, true, GlobalVariable::ExternalLinkage, - Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)), - INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR)); - IRLevelVersionVariable->setVisibility(GlobalValue::DefaultVisibility); - Triple TT(M.getTargetTriple()); - if (!TT.supportsCOMDAT()) - IRLevelVersionVariable->setLinkage(GlobalValue::WeakAnyLinkage); - else - IRLevelVersionVariable->setComdat(M.getOrInsertComdat( - StringRef(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR)))); -} - // Collect the set of members for each Comdat in module M and store // in ComdatMembers. static void collectComdatMembers( @@ -1390,8 +1537,11 @@ static void collectComdatMembers( static bool InstrumentAllFunctions( Module &M, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, - function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) { - createIRLevelProfileFlagVariable(M); + function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) { + // For the context-sensitve instrumentation, we should have a separated pass + // (before LTO/ThinLTO linking) to create these variables. 
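The comment above notes that, for CSPGO, the name/version variables must come from a separate pre-link pass; with the legacy pass manager that is the PGOInstrumentationGenCreateVarLegacyPass introduced earlier in this file. A hypothetical pre-link hookup (driver code assumed, not part of the patch):

#include "llvm/ADT/StringRef.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/Instrumentation.h"
using namespace llvm;

static void addCSPGOPreLinkPasses(legacy::PassManagerBase &PM,
                                  StringRef CSProfileGenFile) {
  // Creates only the profile-filename and IR-level (CS) version variables;
  // the actual CS instrumentation round runs post-link.
  PM.add(createPGOInstrumentationGenCreateVarLegacyPass(CSProfileGenFile));
}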
+ if (!IsCS) + createIRLevelProfileFlagVar(M, /* IsCS */ false); std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers; collectComdatMembers(M, ComdatMembers); @@ -1400,11 +1550,18 @@ static bool InstrumentAllFunctions( continue; auto *BPI = LookupBPI(F); auto *BFI = LookupBFI(F); - instrumentOneFunc(F, &M, BPI, BFI, ComdatMembers); + instrumentOneFunc(F, &M, BPI, BFI, ComdatMembers, IsCS); } return true; } +PreservedAnalyses +PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &AM) { + createProfileFileNameVar(M, CSInstrName); + createIRLevelProfileFlagVar(M, /* IsCS */ true); + return PreservedAnalyses::all(); +} + bool PGOInstrumentationGenLegacyPass::runOnModule(Module &M) { if (skipModule(M)) return false; @@ -1415,7 +1572,7 @@ bool PGOInstrumentationGenLegacyPass::runOnModule(Module &M) { auto LookupBFI = [this](Function &F) { return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); }; - return InstrumentAllFunctions(M, LookupBPI, LookupBFI); + return InstrumentAllFunctions(M, LookupBPI, LookupBFI, IsCS); } PreservedAnalyses PGOInstrumentationGen::run(Module &M, @@ -1429,7 +1586,7 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M, return &FAM.getResult<BlockFrequencyAnalysis>(F); }; - if (!InstrumentAllFunctions(M, LookupBPI, LookupBFI)) + if (!InstrumentAllFunctions(M, LookupBPI, LookupBFI, IsCS)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); @@ -1438,7 +1595,7 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M, static bool annotateAllFunctions( Module &M, StringRef ProfileFileName, StringRef ProfileRemappingFileName, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, - function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) { + function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) { LLVM_DEBUG(dbgs() << "Read in profile counters: "); auto &Ctx = M.getContext(); // Read the counter array from file. @@ -1459,6 +1616,9 @@ static bool annotateAllFunctions( StringRef("Cannot get PGOReader"))); return false; } + if (!PGOReader->hasCSIRLevelProfile() && IsCS) + return false; + // TODO: might need to change the warning once the clang option is finalized. if (!PGOReader->isIRLevelProfile()) { Ctx.diagnose(DiagnosticInfoPGOProfile( @@ -1478,7 +1638,7 @@ static bool annotateAllFunctions( // Split indirectbr critical edges here before computing the MST rather than // later in getInstrBB() to avoid invalidating it. SplitIndirectBrCriticalEdges(F, BPI, BFI); - PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI); + PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI, IsCS); bool AllZeros = false; if (!Func.readCounters(PGOReader.get(), AllZeros)) continue; @@ -1526,7 +1686,10 @@ static bool annotateAllFunctions( } } } - M.setProfileSummary(PGOReader->getSummary().getMD(M.getContext())); + M.setProfileSummary(PGOReader->getSummary(IsCS).getMD(M.getContext()), + IsCS ? ProfileSummary::PSK_CSInstr + : ProfileSummary::PSK_Instr); + // Set function hotness attribute from the profile. 
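Under the new pass manager, the same split shows up as two passes: PGOInstrumentationGenCreateVar pre-link, then a gen round with IsCS set. A sketch of such a pipeline (the constructor signatures are assumptions inferred from the members used in this file, not taken from the header):

#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
using namespace llvm;

static ModulePassManager buildCSPGOGenPipeline(std::string CSProfileFile) {
  ModulePassManager MPM;
  MPM.addPass(PGOInstrumentationGenCreateVar(CSProfileFile)); // assumed ctor
  MPM.addPass(PGOInstrumentationGen(/*IsCS=*/true));          // assumed ctor
  return MPM;
}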
// We have to apply these attributes at the end because their presence // can affect the BranchProbabilityInfo of any callers, resulting in an @@ -1545,9 +1708,10 @@ static bool annotateAllFunctions( } PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename, - std::string RemappingFilename) + std::string RemappingFilename, + bool IsCS) : ProfileFileName(std::move(Filename)), - ProfileRemappingFileName(std::move(RemappingFilename)) { + ProfileRemappingFileName(std::move(RemappingFilename)), IsCS(IsCS) { if (!PGOTestProfileFile.empty()) ProfileFileName = PGOTestProfileFile; if (!PGOTestProfileRemappingFile.empty()) @@ -1567,7 +1731,7 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M, }; if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName, - LookupBPI, LookupBFI)) + LookupBPI, LookupBFI, IsCS)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); @@ -1584,7 +1748,8 @@ bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) { return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); }; - return annotateAllFunctions(M, ProfileFileName, "", LookupBPI, LookupBFI); + return annotateAllFunctions(M, ProfileFileName, "", LookupBPI, LookupBFI, + IsCS); } static std::string getSimpleNodeName(const BasicBlock *Node) { diff --git a/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp index 2c71e75dadcc..188f95b4676b 100644 --- a/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp +++ b/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp @@ -1,9 +1,8 @@ //===-- PGOMemOPSizeOpt.cpp - Optimizations based on value profiling ===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -20,12 +19,12 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" diff --git a/lib/Transforms/Instrumentation/PoisonChecking.cpp b/lib/Transforms/Instrumentation/PoisonChecking.cpp new file mode 100644 index 000000000000..81d92e724c7d --- /dev/null +++ b/lib/Transforms/Instrumentation/PoisonChecking.cpp @@ -0,0 +1,357 @@ +//===- PoisonChecking.cpp - -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements a transform pass which instruments IR such that poison semantics +// are made explicit. That is, it provides a (possibly partial) executable +// semantics for every instruction w.r.t. poison as specified in the LLVM +// LangRef. 
There are obvious parallels to the sanitizer tools, but this pass +// is focused purely on the semantics of LLVM IR, not any particular source +// language. If you're looking for something to see if your C/C++ contains +// UB, this is not it. +// +// The rewritten semantics of each instruction will include the following +// components: +// +// 1) The original instruction, unmodified. +// 2) A propagation rule which translates dynamic information about the poison +// state of each input to whether the dynamic output of the instruction +// produces poison. +// 3) A flag validation rule which validates any poison producing flags on the +// instruction itself (e.g. checks for overflow on nsw). +// 4) A check rule which traps (to a handler function) if this instruction must +// execute undefined behavior given the poison state of its inputs. +// +// At the moment, the UB detection is done in a best effort manner; that is, +// the resulting code may produce a false negative result (not report UB when +// it actually exists according to the LangRef spec), but should never produce +// a false positive (report UB where it doesn't exist). The intention is to +// eventually support a "strict" mode which never dynamically reports a false +// negative at the cost of rejecting some valid inputs to translation. +// +// Use cases for this pass include: +// - Understanding (and testing!) the implications of the definition of poison +// from the LangRef. +// - Validating the output of an IR fuzzer to ensure that all programs produced +// are well defined on the specific input used. +// - Finding/confirming poison specific miscompiles by checking the poison +// status of an input/IR pair is the same before and after an optimization +// transform. +// - Checking that a bugpoint reduction does not introduce UB which didn't +// exist in the original program being reduced. +// +// The major sources of inaccuracy are currently: +// - Most validation rules not yet implemented for instructions with poison +// relevant flags. At the moment, only nsw/nuw on add/sub are supported. +// - UB which is control dependent on a branch on poison is not yet +// reported. Currently, only data flow dependence is modeled. +// - Poison which is propagated through memory is not modeled. As such, +// storing poison to memory and then reloading it will cause a false negative +// as we consider the reloaded value to not be poisoned. +// - Poison propagation across function boundaries is not modeled. At the +// moment, all arguments and return values are assumed not to be poison. +// - Undef is not modeled. In particular, the optimizer's freedom to pick +// concrete values for undef bits so as to maximize potential for producing +// poison is not modeled.
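Concretely, for an "add nsw" the combination of rules (2) and (3) above reduces to: the result is poison if either operand is poison or the nsw guarantee is dynamically violated. A condensed sketch of that computation, mirroring the helpers defined later in this file (PoisonLHS/PoisonRHS stand for the i1 poison indicators the pass tracks for the operands):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
using namespace llvm;

static Value *poisonForNSWAdd(IRBuilder<> &B, Value *LHS, Value *RHS,
                              Value *PoisonLHS, Value *PoisonRHS) {
  // Flag validation: did the signed add actually overflow?
  Value *Overflow = B.CreateExtractValue(
      B.CreateBinaryIntrinsic(Intrinsic::sadd_with_overflow, LHS, RHS), 1);
  // Propagation: poison in either operand, or a violated nsw, poisons the result.
  return B.CreateOr(B.CreateOr(PoisonLHS, PoisonRHS), Overflow);
}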
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation/PoisonChecking.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +using namespace llvm; + +#define DEBUG_TYPE "poison-checking" + +static cl::opt<bool> +LocalCheck("poison-checking-function-local", + cl::init(false), + cl::desc("Check that returns are non-poison (for testing)")); + + +static bool isConstantFalse(Value* V) { + assert(V->getType()->isIntegerTy(1)); + if (auto *CI = dyn_cast<ConstantInt>(V)) + return CI->isZero(); + return false; +} + +static Value *buildOrChain(IRBuilder<> &B, ArrayRef<Value*> Ops) { + if (Ops.size() == 0) + return B.getFalse(); + unsigned i = 0; + for (; i < Ops.size() && isConstantFalse(Ops[i]); i++) {} + if (i == Ops.size()) + return B.getFalse(); + Value *Accum = Ops[i++]; + for (; i < Ops.size(); i++) + if (!isConstantFalse(Ops[i])) + Accum = B.CreateOr(Accum, Ops[i]); + return Accum; +} + +static void generatePoisonChecksForBinOp(Instruction &I, + SmallVector<Value*, 2> &Checks) { + assert(isa<BinaryOperator>(I)); + + IRBuilder<> B(&I); + Value *LHS = I.getOperand(0); + Value *RHS = I.getOperand(1); + switch (I.getOpcode()) { + default: + return; + case Instruction::Add: { + if (I.hasNoSignedWrap()) { + auto *OverflowOp = + B.CreateBinaryIntrinsic(Intrinsic::sadd_with_overflow, LHS, RHS); + Checks.push_back(B.CreateExtractValue(OverflowOp, 1)); + } + if (I.hasNoUnsignedWrap()) { + auto *OverflowOp = + B.CreateBinaryIntrinsic(Intrinsic::uadd_with_overflow, LHS, RHS); + Checks.push_back(B.CreateExtractValue(OverflowOp, 1)); + } + break; + } + case Instruction::Sub: { + if (I.hasNoSignedWrap()) { + auto *OverflowOp = + B.CreateBinaryIntrinsic(Intrinsic::ssub_with_overflow, LHS, RHS); + Checks.push_back(B.CreateExtractValue(OverflowOp, 1)); + } + if (I.hasNoUnsignedWrap()) { + auto *OverflowOp = + B.CreateBinaryIntrinsic(Intrinsic::usub_with_overflow, LHS, RHS); + Checks.push_back(B.CreateExtractValue(OverflowOp, 1)); + } + break; + } + case Instruction::Mul: { + if (I.hasNoSignedWrap()) { + auto *OverflowOp = + B.CreateBinaryIntrinsic(Intrinsic::smul_with_overflow, LHS, RHS); + Checks.push_back(B.CreateExtractValue(OverflowOp, 1)); + } + if (I.hasNoUnsignedWrap()) { + auto *OverflowOp = + B.CreateBinaryIntrinsic(Intrinsic::umul_with_overflow, LHS, RHS); + Checks.push_back(B.CreateExtractValue(OverflowOp, 1)); + } + break; + } + case Instruction::UDiv: { + if (I.isExact()) { + auto *Check = + B.CreateICmp(ICmpInst::ICMP_NE, B.CreateURem(LHS, RHS), + ConstantInt::get(LHS->getType(), 0)); + Checks.push_back(Check); + } + break; + } + case Instruction::SDiv: { + if (I.isExact()) { + auto *Check = + B.CreateICmp(ICmpInst::ICMP_NE, B.CreateSRem(LHS, RHS), + ConstantInt::get(LHS->getType(), 0)); + Checks.push_back(Check); + } + break; + } + case Instruction::AShr: + case Instruction::LShr: + case Instruction::Shl: { + Value *ShiftCheck = + B.CreateICmp(ICmpInst::ICMP_UGE, RHS, + ConstantInt::get(RHS->getType(), + LHS->getType()->getScalarSizeInBits())); + Checks.push_back(ShiftCheck); + break; + } + }; +} + +static Value* generatePoisonChecks(Instruction &I) { + IRBuilder<> B(&I); + SmallVector<Value*, 2> Checks; + if (isa<BinaryOperator>(I) && 
!I.getType()->isVectorTy()) + generatePoisonChecksForBinOp(I, Checks); + + // Handle non-binops seperately + switch (I.getOpcode()) { + default: + break; + case Instruction::ExtractElement: { + Value *Vec = I.getOperand(0); + if (Vec->getType()->getVectorIsScalable()) + break; + Value *Idx = I.getOperand(1); + unsigned NumElts = Vec->getType()->getVectorNumElements(); + Value *Check = + B.CreateICmp(ICmpInst::ICMP_UGE, Idx, + ConstantInt::get(Idx->getType(), NumElts)); + Checks.push_back(Check); + break; + } + case Instruction::InsertElement: { + Value *Vec = I.getOperand(0); + if (Vec->getType()->getVectorIsScalable()) + break; + Value *Idx = I.getOperand(2); + unsigned NumElts = Vec->getType()->getVectorNumElements(); + Value *Check = + B.CreateICmp(ICmpInst::ICMP_UGE, Idx, + ConstantInt::get(Idx->getType(), NumElts)); + Checks.push_back(Check); + break; + } + }; + return buildOrChain(B, Checks); +} + +static Value *getPoisonFor(DenseMap<Value *, Value *> &ValToPoison, Value *V) { + auto Itr = ValToPoison.find(V); + if (Itr != ValToPoison.end()) + return Itr->second; + if (isa<Constant>(V)) { + return ConstantInt::getFalse(V->getContext()); + } + // Return false for unknwon values - this implements a non-strict mode where + // unhandled IR constructs are simply considered to never produce poison. At + // some point in the future, we probably want a "strict mode" for testing if + // nothing else. + return ConstantInt::getFalse(V->getContext()); +} + +static void CreateAssert(IRBuilder<> &B, Value *Cond) { + assert(Cond->getType()->isIntegerTy(1)); + if (auto *CI = dyn_cast<ConstantInt>(Cond)) + if (CI->isAllOnesValue()) + return; + + Module *M = B.GetInsertBlock()->getModule(); + M->getOrInsertFunction("__poison_checker_assert", + Type::getVoidTy(M->getContext()), + Type::getInt1Ty(M->getContext())); + Function *TrapFunc = M->getFunction("__poison_checker_assert"); + B.CreateCall(TrapFunc, Cond); +} + +static void CreateAssertNot(IRBuilder<> &B, Value *Cond) { + assert(Cond->getType()->isIntegerTy(1)); + CreateAssert(B, B.CreateNot(Cond)); +} + +static bool rewrite(Function &F) { + auto * const Int1Ty = Type::getInt1Ty(F.getContext()); + + DenseMap<Value *, Value *> ValToPoison; + + for (BasicBlock &BB : F) + for (auto I = BB.begin(); isa<PHINode>(&*I); I++) { + auto *OldPHI = cast<PHINode>(&*I); + auto *NewPHI = PHINode::Create(Int1Ty, + OldPHI->getNumIncomingValues()); + for (unsigned i = 0; i < OldPHI->getNumIncomingValues(); i++) + NewPHI->addIncoming(UndefValue::get(Int1Ty), + OldPHI->getIncomingBlock(i)); + NewPHI->insertBefore(OldPHI); + ValToPoison[OldPHI] = NewPHI; + } + + for (BasicBlock &BB : F) + for (Instruction &I : BB) { + if (isa<PHINode>(I)) continue; + + IRBuilder<> B(cast<Instruction>(&I)); + + // Note: There are many more sources of documented UB, but this pass only + // attempts to find UB triggered by propagation of poison. 
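The checks built by rewrite() call out to a runtime hook named __poison_checker_assert, passed a single i1 that is expected to be true. The patch does not ship a runtime; a minimal stand-in definition (its behavior is an assumption, shown only so an instrumented module can be linked and exercised) could be:

#include <cstdio>
#include <cstdlib>

extern "C" void __poison_checker_assert(bool MustBeTrue) {
  // Called with the negation of the tracked poison bit (see CreateAssertNot).
  if (!MustBeTrue) {
    std::fprintf(stderr, "poison-checking: dynamic poison/UB check failed\n");
    std::abort();
  }
}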
+ if (Value *Op = const_cast<Value*>(getGuaranteedNonFullPoisonOp(&I))) + CreateAssertNot(B, getPoisonFor(ValToPoison, Op)); + + if (LocalCheck) + if (auto *RI = dyn_cast<ReturnInst>(&I)) + if (RI->getNumOperands() != 0) { + Value *Op = RI->getOperand(0); + CreateAssertNot(B, getPoisonFor(ValToPoison, Op)); + } + + SmallVector<Value*, 4> Checks; + if (propagatesFullPoison(&I)) + for (Value *V : I.operands()) + Checks.push_back(getPoisonFor(ValToPoison, V)); + + if (auto *Check = generatePoisonChecks(I)) + Checks.push_back(Check); + ValToPoison[&I] = buildOrChain(B, Checks); + } + + for (BasicBlock &BB : F) + for (auto I = BB.begin(); isa<PHINode>(&*I); I++) { + auto *OldPHI = cast<PHINode>(&*I); + if (!ValToPoison.count(OldPHI)) + continue; // skip the newly inserted phis + auto *NewPHI = cast<PHINode>(ValToPoison[OldPHI]); + for (unsigned i = 0; i < OldPHI->getNumIncomingValues(); i++) { + auto *OldVal = OldPHI->getIncomingValue(i); + NewPHI->setIncomingValue(i, getPoisonFor(ValToPoison, OldVal)); + } + } + return true; +} + + +PreservedAnalyses PoisonCheckingPass::run(Module &M, + ModuleAnalysisManager &AM) { + bool Changed = false; + for (auto &F : M) + Changed |= rewrite(F); + + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} + +PreservedAnalyses PoisonCheckingPass::run(Function &F, + FunctionAnalysisManager &AM) { + return rewrite(F) ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} + + +/* Major TODO Items: + - Control dependent poison UB + - Strict mode - (i.e. must analyze every operand) + - Poison through memory + - Function ABIs + - Full coverage of intrinsics, etc.. (ouch) + + Instructions w/Unclear Semantics: + - shufflevector - It would seem reasonable for an out of bounds mask element + to produce poison, but the LangRef does not state. + - and/or - It would seem reasonable for poison to propagate from both + arguments, but LangRef doesn't state and propagatesFullPoison doesn't + include these two. + - all binary ops w/vector operands - The likely interpretation would be that + any element overflowing should produce poison for the entire result, but + the LangRef does not state. + - Floating point binary ops w/fmf flags other than (nnan, noinfs). It seems + strange that only certian flags should be documented as producing poison. + + Cases of clear poison semantics not yet implemented: + - Exact flags on ashr/lshr produce poison + - NSW/NUW flags on shl produce poison + - Inbounds flag on getelementptr produce poison + - fptosi/fptoui (out of bounds input) produce poison + - Scalable vector types for insertelement/extractelement + - Floating point binary ops w/fmf nnan/noinfs flags produce poison + */ diff --git a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index 0ba8d5765e8c..ca0cb4bdbe84 100644 --- a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -1,9 +1,8 @@ //===-- SanitizerCoverage.cpp - coverage instrumentation for sanitizers ---===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -62,7 +61,10 @@ static const char *const SanCovTraceDiv4 = "__sanitizer_cov_trace_div4"; static const char *const SanCovTraceDiv8 = "__sanitizer_cov_trace_div8"; static const char *const SanCovTraceGep = "__sanitizer_cov_trace_gep"; static const char *const SanCovTraceSwitchName = "__sanitizer_cov_trace_switch"; -static const char *const SanCovModuleCtorName = "sancov.module_ctor"; +static const char *const SanCovModuleCtorTracePcGuardName = + "sancov.module_ctor_trace_pc_guard"; +static const char *const SanCovModuleCtor8bitCountersName = + "sancov.module_ctor_8bit_counters"; static const uint64_t SanCtorAndDtorPriority = 2; static const char *const SanCovTracePCGuardName = @@ -210,8 +212,9 @@ private: void CreateFunctionLocalArrays(Function &F, ArrayRef<BasicBlock *> AllBlocks); void InjectCoverageAtBlock(Function &F, BasicBlock &BB, size_t Idx, bool IsLeafFunc = true); - Function *CreateInitCallsForSections(Module &M, const char *InitFunctionName, - Type *Ty, const char *Section); + Function *CreateInitCallsForSections(Module &M, const char *CtorName, + const char *InitFunctionName, Type *Ty, + const char *Section); std::pair<Value *, Value *> CreateSecStartEnd(Module &M, const char *Section, Type *Ty); @@ -223,13 +226,13 @@ private: std::string getSectionName(const std::string &Section) const; std::string getSectionStart(const std::string &Section) const; std::string getSectionEnd(const std::string &Section) const; - Function *SanCovTracePCIndir; - Function *SanCovTracePC, *SanCovTracePCGuard; - Function *SanCovTraceCmpFunction[4]; - Function *SanCovTraceConstCmpFunction[4]; - Function *SanCovTraceDivFunction[2]; - Function *SanCovTraceGepFunction; - Function *SanCovTraceSwitchFunction; + FunctionCallee SanCovTracePCIndir; + FunctionCallee SanCovTracePC, SanCovTracePCGuard; + FunctionCallee SanCovTraceCmpFunction[4]; + FunctionCallee SanCovTraceConstCmpFunction[4]; + FunctionCallee SanCovTraceDivFunction[2]; + FunctionCallee SanCovTraceGepFunction; + FunctionCallee SanCovTraceSwitchFunction; GlobalVariable *SanCovLowestStack; InlineAsm *EmptyAsm; Type *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, *Int32PtrTy, @@ -270,24 +273,25 @@ SanitizerCoverageModule::CreateSecStartEnd(Module &M, const char *Section, // Account for the fact that on windows-msvc __start_* symbols actually // point to a uint64_t before the start of the array. auto SecStartI8Ptr = IRB.CreatePointerCast(SecStart, Int8PtrTy); - auto GEP = IRB.CreateGEP(SecStartI8Ptr, + auto GEP = IRB.CreateGEP(Int8Ty, SecStartI8Ptr, ConstantInt::get(IntptrTy, sizeof(uint64_t))); return std::make_pair(IRB.CreatePointerCast(GEP, Ty), SecEndPtr); } Function *SanitizerCoverageModule::CreateInitCallsForSections( - Module &M, const char *InitFunctionName, Type *Ty, + Module &M, const char *CtorName, const char *InitFunctionName, Type *Ty, const char *Section) { auto SecStartEnd = CreateSecStartEnd(M, Section, Ty); auto SecStart = SecStartEnd.first; auto SecEnd = SecStartEnd.second; Function *CtorFunc; std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( - M, SanCovModuleCtorName, InitFunctionName, {Ty, Ty}, {SecStart, SecEnd}); + M, CtorName, InitFunctionName, {Ty, Ty}, {SecStart, SecEnd}); + assert(CtorFunc->getName() == CtorName); if (TargetTriple.supportsCOMDAT()) { // Use comdat to dedup CtorFunc. 
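Each of the two module ctors named above ends up calling the corresponding public init callback with the bounds of its section. For the trace-pc-guard flavor, the user-side callback has the signature documented for SanitizerCoverage; the body below is the usual minimal example, included here only to make the start/stop plumbing concrete:

#include <cstdint>
#include <cstdio>

extern "C" void __sanitizer_cov_trace_pc_guard_init(uint32_t *Start,
                                                    uint32_t *Stop) {
  static uint64_t N = 0;        // counter shared across all loaded modules
  if (Start == Stop || *Start)  // empty range or already initialized
    return;
  for (uint32_t *G = Start; G < Stop; G++)
    *G = ++N;                   // give every guard a distinct non-zero ID
  std::printf("sancov: %td guards initialized\n", Stop - Start);
}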
- CtorFunc->setComdat(M.getOrInsertComdat(SanCovModuleCtorName)); + CtorFunc->setComdat(M.getOrInsertComdat(CtorName)); appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority, CtorFunc); } else { appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority); @@ -329,77 +333,74 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { Int16Ty = IRB.getInt16Ty(); Int8Ty = IRB.getInt8Ty(); - SanCovTracePCIndir = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(SanCovTracePCIndirName, VoidTy, IntptrTy)); + SanCovTracePCIndir = + M.getOrInsertFunction(SanCovTracePCIndirName, VoidTy, IntptrTy); + // Make sure smaller parameters are zero-extended to i64 as required by the + // x86_64 ABI. + AttributeList SanCovTraceCmpZeroExtAL; + if (TargetTriple.getArch() == Triple::x86_64) { + SanCovTraceCmpZeroExtAL = + SanCovTraceCmpZeroExtAL.addParamAttribute(*C, 0, Attribute::ZExt); + SanCovTraceCmpZeroExtAL = + SanCovTraceCmpZeroExtAL.addParamAttribute(*C, 1, Attribute::ZExt); + } + SanCovTraceCmpFunction[0] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceCmp1, VoidTy, IRB.getInt8Ty(), IRB.getInt8Ty())); - SanCovTraceCmpFunction[1] = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(SanCovTraceCmp2, VoidTy, IRB.getInt16Ty(), - IRB.getInt16Ty())); - SanCovTraceCmpFunction[2] = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(SanCovTraceCmp4, VoidTy, IRB.getInt32Ty(), - IRB.getInt32Ty())); + M.getOrInsertFunction(SanCovTraceCmp1, SanCovTraceCmpZeroExtAL, VoidTy, + IRB.getInt8Ty(), IRB.getInt8Ty()); + SanCovTraceCmpFunction[1] = + M.getOrInsertFunction(SanCovTraceCmp2, SanCovTraceCmpZeroExtAL, VoidTy, + IRB.getInt16Ty(), IRB.getInt16Ty()); + SanCovTraceCmpFunction[2] = + M.getOrInsertFunction(SanCovTraceCmp4, SanCovTraceCmpZeroExtAL, VoidTy, + IRB.getInt32Ty(), IRB.getInt32Ty()); SanCovTraceCmpFunction[3] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceCmp8, VoidTy, Int64Ty, Int64Ty)); - - SanCovTraceConstCmpFunction[0] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceConstCmp1, VoidTy, Int8Ty, Int8Ty)); - SanCovTraceConstCmpFunction[1] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceConstCmp2, VoidTy, Int16Ty, Int16Ty)); - SanCovTraceConstCmpFunction[2] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceConstCmp4, VoidTy, Int32Ty, Int32Ty)); + M.getOrInsertFunction(SanCovTraceCmp8, VoidTy, Int64Ty, Int64Ty); + + SanCovTraceConstCmpFunction[0] = M.getOrInsertFunction( + SanCovTraceConstCmp1, SanCovTraceCmpZeroExtAL, VoidTy, Int8Ty, Int8Ty); + SanCovTraceConstCmpFunction[1] = M.getOrInsertFunction( + SanCovTraceConstCmp2, SanCovTraceCmpZeroExtAL, VoidTy, Int16Ty, Int16Ty); + SanCovTraceConstCmpFunction[2] = M.getOrInsertFunction( + SanCovTraceConstCmp4, SanCovTraceCmpZeroExtAL, VoidTy, Int32Ty, Int32Ty); SanCovTraceConstCmpFunction[3] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceConstCmp8, VoidTy, Int64Ty, Int64Ty)); - - SanCovTraceDivFunction[0] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceDiv4, VoidTy, IRB.getInt32Ty())); + M.getOrInsertFunction(SanCovTraceConstCmp8, VoidTy, Int64Ty, Int64Ty); + + { + AttributeList AL; + if (TargetTriple.getArch() == Triple::x86_64) + AL = AL.addParamAttribute(*C, 0, Attribute::ZExt); + SanCovTraceDivFunction[0] = + M.getOrInsertFunction(SanCovTraceDiv4, AL, VoidTy, IRB.getInt32Ty()); + } SanCovTraceDivFunction[1] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - 
SanCovTraceDiv8, VoidTy, Int64Ty)); + M.getOrInsertFunction(SanCovTraceDiv8, VoidTy, Int64Ty); SanCovTraceGepFunction = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceGep, VoidTy, IntptrTy)); + M.getOrInsertFunction(SanCovTraceGep, VoidTy, IntptrTy); SanCovTraceSwitchFunction = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy)); + M.getOrInsertFunction(SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy); Constant *SanCovLowestStackConstant = M.getOrInsertGlobal(SanCovLowestStackName, IntptrTy); - SanCovLowestStack = cast<GlobalVariable>(SanCovLowestStackConstant); + SanCovLowestStack = dyn_cast<GlobalVariable>(SanCovLowestStackConstant); + if (!SanCovLowestStack) { + C->emitError(StringRef("'") + SanCovLowestStackName + + "' should not be declared by the user"); + return true; + } SanCovLowestStack->setThreadLocalMode( GlobalValue::ThreadLocalMode::InitialExecTLSModel); if (Options.StackDepth && !SanCovLowestStack->isDeclaration()) SanCovLowestStack->setInitializer(Constant::getAllOnesValue(IntptrTy)); - // Make sure smaller parameters are zero-extended to i64 as required by the - // x86_64 ABI. - if (TargetTriple.getArch() == Triple::x86_64) { - for (int i = 0; i < 3; i++) { - SanCovTraceCmpFunction[i]->addParamAttr(0, Attribute::ZExt); - SanCovTraceCmpFunction[i]->addParamAttr(1, Attribute::ZExt); - SanCovTraceConstCmpFunction[i]->addParamAttr(0, Attribute::ZExt); - SanCovTraceConstCmpFunction[i]->addParamAttr(1, Attribute::ZExt); - } - SanCovTraceDivFunction[0]->addParamAttr(0, Attribute::ZExt); - } - - // We insert an empty inline asm after cov callbacks to avoid callback merge. EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), StringRef(""), StringRef(""), /*hasSideEffects=*/true); - SanCovTracePC = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(SanCovTracePCName, VoidTy)); - SanCovTracePCGuard = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTracePCGuardName, VoidTy, Int32PtrTy)); + SanCovTracePC = M.getOrInsertFunction(SanCovTracePCName, VoidTy); + SanCovTracePCGuard = + M.getOrInsertFunction(SanCovTracePCGuardName, VoidTy, Int32PtrTy); for (auto &F : M) runOnFunction(F); @@ -407,14 +408,16 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { Function *Ctor = nullptr; if (FunctionGuardArray) - Ctor = CreateInitCallsForSections(M, SanCovTracePCGuardInitName, Int32PtrTy, + Ctor = CreateInitCallsForSections(M, SanCovModuleCtorTracePcGuardName, + SanCovTracePCGuardInitName, Int32PtrTy, SanCovGuardsSectionName); if (Function8bitCounterArray) - Ctor = CreateInitCallsForSections(M, SanCov8bitCountersInitName, Int8PtrTy, + Ctor = CreateInitCallsForSections(M, SanCovModuleCtor8bitCountersName, + SanCov8bitCountersInitName, Int8PtrTy, SanCovCountersSectionName); if (Ctor && Options.PCTable) { auto SecStartEnd = CreateSecStartEnd(M, SanCovPCsSectionName, IntptrPtrTy); - Function *InitFunction = declareSanitizerInitFunction( + FunctionCallee InitFunction = declareSanitizerInitFunction( M, SanCovPCsInitName, {IntptrPtrTy, IntptrPtrTy}); IRBuilder<> IRBCtor(Ctor->getEntryBlock().getTerminator()); IRBCtor.CreateCall(InitFunction, {SecStartEnd.first, SecStartEnd.second}); @@ -458,12 +461,12 @@ static bool shouldInstrumentBlock(const Function &F, const BasicBlock *BB, const DominatorTree *DT, const PostDominatorTree *PDT, const SanitizerCoverageOptions &Options) { - // Don't insert coverage for unreachable blocks: we will never call - // __sanitizer_cov() for them, so 
counting them in + // Don't insert coverage for blocks containing nothing but unreachable: we + // will never call __sanitizer_cov() for them, so counting them in // NumberOfInstrumentedBlocks() might complicate calculation of code coverage // percentage. Also, unreachable instructions frequently have no debug // locations. - if (isa<UnreachableInst>(BB->getTerminator())) + if (isa<UnreachableInst>(BB->getFirstNonPHIOrDbgOrLifetime())) return false; // Don't insert coverage into blocks without a valid insertion point @@ -484,6 +487,37 @@ static bool shouldInstrumentBlock(const Function &F, const BasicBlock *BB, && !(isFullPostDominator(BB, PDT) && !BB->getSinglePredecessor()); } + +// Returns true iff From->To is a backedge. +// A twist here is that we treat From->To as a backedge if +// * To dominates From or +// * To->UniqueSuccessor dominates From +static bool IsBackEdge(BasicBlock *From, BasicBlock *To, + const DominatorTree *DT) { + if (DT->dominates(To, From)) + return true; + if (auto Next = To->getUniqueSuccessor()) + if (DT->dominates(Next, From)) + return true; + return false; +} + +// Prunes uninteresting Cmp instrumentation: +// * CMP instructions that feed into loop backedge branch. +// +// Note that Cmp pruning is controlled by the same flag as the +// BB pruning. +static bool IsInterestingCmp(ICmpInst *CMP, const DominatorTree *DT, + const SanitizerCoverageOptions &Options) { + if (!Options.NoPrune) + if (CMP->hasOneUse()) + if (auto BR = dyn_cast<BranchInst>(CMP->user_back())) + for (BasicBlock *B : BR->successors()) + if (IsBackEdge(BR->getParent(), B, DT)) + return false; + return true; +} + bool SanitizerCoverageModule::runOnFunction(Function &F) { if (F.empty()) return false; @@ -508,7 +542,7 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { isAsynchronousEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) return false; if (Options.CoverageType >= SanitizerCoverageOptions::SCK_Edge) - SplitAllCriticalEdges(F); + SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions().setIgnoreUnreachableDests()); SmallVector<Instruction *, 8> IndirCalls; SmallVector<BasicBlock *, 16> BlocksToInstrument; SmallVector<Instruction *, 8> CmpTraceTargets; @@ -532,8 +566,9 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { IndirCalls.push_back(&Inst); } if (Options.TraceCmp) { - if (isa<ICmpInst>(&Inst)) - CmpTraceTargets.push_back(&Inst); + if (ICmpInst *CMP = dyn_cast<ICmpInst>(&Inst)) + if (IsInterestingCmp(CMP, DT, Options)) + CmpTraceTargets.push_back(&Inst); if (isa<SwitchInst>(&Inst)) SwitchTraceTargets.push_back(&Inst); } @@ -797,9 +832,9 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, } if (Options.Inline8bitCounters) { auto CounterPtr = IRB.CreateGEP( - Function8bitCounterArray, + Function8bitCounterArray->getValueType(), Function8bitCounterArray, {ConstantInt::get(IntptrTy, 0), ConstantInt::get(IntptrTy, Idx)}); - auto Load = IRB.CreateLoad(CounterPtr); + auto Load = IRB.CreateLoad(Int8Ty, CounterPtr); auto Inc = IRB.CreateAdd(Load, ConstantInt::get(Int8Ty, 1)); auto Store = IRB.CreateStore(Inc, CounterPtr); SetNoSanitizeMetadata(Load); @@ -812,7 +847,7 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, auto FrameAddrPtr = IRB.CreateCall(GetFrameAddr, {Constant::getNullValue(Int32Ty)}); auto FrameAddrInt = IRB.CreatePtrToInt(FrameAddrPtr, IntptrTy); - auto LowestStack = IRB.CreateLoad(SanCovLowestStack); + auto LowestStack = IRB.CreateLoad(IntptrTy, SanCovLowestStack); auto 
IsStackLower = IRB.CreateICmpULT(FrameAddrInt, LowestStack); auto ThenTerm = SplitBlockAndInsertIfThen(IsStackLower, &*IP, false); IRBuilder<> ThenIRB(ThenTerm); diff --git a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index 077364e15c4f..5be13fa745cb 100644 --- a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -1,9 +1,8 @@ //===-- ThreadSanitizer.cpp - race detector -------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -111,25 +110,26 @@ private: Type *IntptrTy; IntegerType *OrdTy; // Callbacks to run-time library are computed in doInitialization. - Function *TsanFuncEntry; - Function *TsanFuncExit; - Function *TsanIgnoreBegin; - Function *TsanIgnoreEnd; + FunctionCallee TsanFuncEntry; + FunctionCallee TsanFuncExit; + FunctionCallee TsanIgnoreBegin; + FunctionCallee TsanIgnoreEnd; // Accesses sizes are powers of two: 1, 2, 4, 8, 16. static const size_t kNumberOfAccessSizes = 5; - Function *TsanRead[kNumberOfAccessSizes]; - Function *TsanWrite[kNumberOfAccessSizes]; - Function *TsanUnalignedRead[kNumberOfAccessSizes]; - Function *TsanUnalignedWrite[kNumberOfAccessSizes]; - Function *TsanAtomicLoad[kNumberOfAccessSizes]; - Function *TsanAtomicStore[kNumberOfAccessSizes]; - Function *TsanAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][kNumberOfAccessSizes]; - Function *TsanAtomicCAS[kNumberOfAccessSizes]; - Function *TsanAtomicThreadFence; - Function *TsanAtomicSignalFence; - Function *TsanVptrUpdate; - Function *TsanVptrLoad; - Function *MemmoveFn, *MemcpyFn, *MemsetFn; + FunctionCallee TsanRead[kNumberOfAccessSizes]; + FunctionCallee TsanWrite[kNumberOfAccessSizes]; + FunctionCallee TsanUnalignedRead[kNumberOfAccessSizes]; + FunctionCallee TsanUnalignedWrite[kNumberOfAccessSizes]; + FunctionCallee TsanAtomicLoad[kNumberOfAccessSizes]; + FunctionCallee TsanAtomicStore[kNumberOfAccessSizes]; + FunctionCallee TsanAtomicRMW[AtomicRMWInst::LAST_BINOP + 1] + [kNumberOfAccessSizes]; + FunctionCallee TsanAtomicCAS[kNumberOfAccessSizes]; + FunctionCallee TsanAtomicThreadFence; + FunctionCallee TsanAtomicSignalFence; + FunctionCallee TsanVptrUpdate; + FunctionCallee TsanVptrLoad; + FunctionCallee MemmoveFn, MemcpyFn, MemsetFn; Function *TsanCtorFunction; }; @@ -189,14 +189,14 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { Attr = Attr.addAttribute(M.getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind); // Initialize the callbacks. 
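The switch from Function * to FunctionCallee above reflects the updated Module::getOrInsertFunction, which now returns the callee bundled with its function type; the result can be handed straight to IRBuilder. A small standalone illustration of the pattern used throughout initializeCallbacks():

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
using namespace llvm;

static CallInst *emitTsanFuncEntry(Module &M, IRBuilder<> &IRB,
                                   Value *ReturnAddress) {
  // Declare (or reuse) void __tsan_func_entry(i8*).
  FunctionCallee Callee = M.getOrInsertFunction(
      "__tsan_func_entry", IRB.getVoidTy(), IRB.getInt8PtrTy());
  return IRB.CreateCall(Callee, ReturnAddress);
}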
- TsanFuncEntry = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_func_entry", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); - TsanFuncExit = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy())); - TsanIgnoreBegin = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_ignore_thread_begin", Attr, IRB.getVoidTy())); - TsanIgnoreEnd = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_ignore_thread_end", Attr, IRB.getVoidTy())); + TsanFuncEntry = M.getOrInsertFunction("__tsan_func_entry", Attr, + IRB.getVoidTy(), IRB.getInt8PtrTy()); + TsanFuncExit = + M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy()); + TsanIgnoreBegin = M.getOrInsertFunction("__tsan_ignore_thread_begin", Attr, + IRB.getVoidTy()); + TsanIgnoreEnd = + M.getOrInsertFunction("__tsan_ignore_thread_end", Attr, IRB.getVoidTy()); OrdTy = IRB.getInt32Ty(); for (size_t i = 0; i < kNumberOfAccessSizes; ++i) { const unsigned ByteSize = 1U << i; @@ -204,32 +204,30 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { std::string ByteSizeStr = utostr(ByteSize); std::string BitSizeStr = utostr(BitSize); SmallString<32> ReadName("__tsan_read" + ByteSizeStr); - TsanRead[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - ReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); + TsanRead[i] = M.getOrInsertFunction(ReadName, Attr, IRB.getVoidTy(), + IRB.getInt8PtrTy()); SmallString<32> WriteName("__tsan_write" + ByteSizeStr); - TsanWrite[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - WriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); + TsanWrite[i] = M.getOrInsertFunction(WriteName, Attr, IRB.getVoidTy(), + IRB.getInt8PtrTy()); SmallString<64> UnalignedReadName("__tsan_unaligned_read" + ByteSizeStr); - TsanUnalignedRead[i] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); + TsanUnalignedRead[i] = M.getOrInsertFunction( + UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); SmallString<64> UnalignedWriteName("__tsan_unaligned_write" + ByteSizeStr); - TsanUnalignedWrite[i] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); + TsanUnalignedWrite[i] = M.getOrInsertFunction( + UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy()); Type *Ty = Type::getIntNTy(M.getContext(), BitSize); Type *PtrTy = Ty->getPointerTo(); SmallString<32> AtomicLoadName("__tsan_atomic" + BitSizeStr + "_load"); - TsanAtomicLoad[i] = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(AtomicLoadName, Attr, Ty, PtrTy, OrdTy)); + TsanAtomicLoad[i] = + M.getOrInsertFunction(AtomicLoadName, Attr, Ty, PtrTy, OrdTy); SmallString<32> AtomicStoreName("__tsan_atomic" + BitSizeStr + "_store"); - TsanAtomicStore[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - AtomicStoreName, Attr, IRB.getVoidTy(), PtrTy, Ty, OrdTy)); + TsanAtomicStore[i] = M.getOrInsertFunction( + AtomicStoreName, Attr, IRB.getVoidTy(), PtrTy, Ty, OrdTy); for (int op = AtomicRMWInst::FIRST_BINOP; op <= AtomicRMWInst::LAST_BINOP; ++op) { @@ -252,34 +250,34 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { else continue; SmallString<32> RMWName("__tsan_atomic" + itostr(BitSize) + NamePart); - TsanAtomicRMW[op][i] = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(RMWName, Attr, Ty, PtrTy, Ty, OrdTy)); + TsanAtomicRMW[op][i] = + M.getOrInsertFunction(RMWName, Attr, Ty, PtrTy, Ty, 
OrdTy); } SmallString<32> AtomicCASName("__tsan_atomic" + BitSizeStr + "_compare_exchange_val"); - TsanAtomicCAS[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - AtomicCASName, Attr, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy)); + TsanAtomicCAS[i] = M.getOrInsertFunction(AtomicCASName, Attr, Ty, PtrTy, Ty, + Ty, OrdTy, OrdTy); } - TsanVptrUpdate = checkSanitizerInterfaceFunction( + TsanVptrUpdate = M.getOrInsertFunction("__tsan_vptr_update", Attr, IRB.getVoidTy(), - IRB.getInt8PtrTy(), IRB.getInt8PtrTy())); - TsanVptrLoad = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_vptr_read", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy())); - TsanAtomicThreadFence = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_atomic_thread_fence", Attr, IRB.getVoidTy(), OrdTy)); - TsanAtomicSignalFence = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_atomic_signal_fence", Attr, IRB.getVoidTy(), OrdTy)); - - MemmoveFn = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("memmove", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy)); - MemcpyFn = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("memcpy", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt8PtrTy(), IntptrTy)); - MemsetFn = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("memset", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), - IRB.getInt32Ty(), IntptrTy)); + IRB.getInt8PtrTy(), IRB.getInt8PtrTy()); + TsanVptrLoad = M.getOrInsertFunction("__tsan_vptr_read", Attr, + IRB.getVoidTy(), IRB.getInt8PtrTy()); + TsanAtomicThreadFence = M.getOrInsertFunction("__tsan_atomic_thread_fence", + Attr, IRB.getVoidTy(), OrdTy); + TsanAtomicSignalFence = M.getOrInsertFunction("__tsan_atomic_signal_fence", + Attr, IRB.getVoidTy(), OrdTy); + + MemmoveFn = + M.getOrInsertFunction("memmove", Attr, IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); + MemcpyFn = + M.getOrInsertFunction("memcpy", Attr, IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); + MemsetFn = + M.getOrInsertFunction("memset", Attr, IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy); } ThreadSanitizer::ThreadSanitizer(Module &M) { @@ -291,7 +289,9 @@ ThreadSanitizer::ThreadSanitizer(Module &M) { /*InitArgs=*/{}, // This callback is invoked when the functions are created the first // time. Hook them into the global ctors list in that case: - [&](Function *Ctor, Function *) { appendToGlobalCtors(M, Ctor, 0); }); + [&](Function *Ctor, FunctionCallee) { + appendToGlobalCtors(M, Ctor, 0); + }); } static bool isVtableAccess(Instruction *I) { @@ -559,7 +559,7 @@ bool ThreadSanitizer::instrumentLoadOrStore(Instruction *I, : cast<LoadInst>(I)->getAlignment(); Type *OrigTy = cast<PointerType>(Addr->getType())->getElementType(); const uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy); - Value *OnAccessFunc = nullptr; + FunctionCallee OnAccessFunc = nullptr; if (Alignment == 0 || Alignment >= 8 || (Alignment % (TypeSize / 8)) == 0) OnAccessFunc = IsWrite ? 
TsanWrite[Idx] : TsanRead[Idx]; else @@ -659,7 +659,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { int Idx = getMemoryAccessFuncIndex(Addr, DL); if (Idx < 0) return false; - Function *F = TsanAtomicRMW[RMWI->getOperation()][Idx]; + FunctionCallee F = TsanAtomicRMW[RMWI->getOperation()][Idx]; if (!F) return false; const unsigned ByteSize = 1U << Idx; @@ -706,8 +706,9 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { I->eraseFromParent(); } else if (FenceInst *FI = dyn_cast<FenceInst>(I)) { Value *Args[] = {createOrdering(&IRB, FI->getOrdering())}; - Function *F = FI->getSyncScopeID() == SyncScope::SingleThread ? - TsanAtomicSignalFence : TsanAtomicThreadFence; + FunctionCallee F = FI->getSyncScopeID() == SyncScope::SingleThread + ? TsanAtomicSignalFence + : TsanAtomicThreadFence; CallInst *C = CallInst::Create(F, Args); ReplaceInstWithInst(I, C); } diff --git a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h index 7f6b157304a3..e1e95cd6a407 100644 --- a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h +++ b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h @@ -1,9 +1,8 @@ //===- ARCRuntimeEntryPoints.h - ObjC ARC Optimization ----------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -34,7 +33,7 @@ namespace llvm { -class Constant; +class Function; class LLVMContext; namespace objcarc { @@ -70,7 +69,7 @@ public: RetainAutoreleaseRV = nullptr; } - Constant *get(ARCRuntimeEntryPointKind kind) { + Function *get(ARCRuntimeEntryPointKind kind) { assert(TheModule != nullptr && "Not initialized."); switch (kind) { @@ -106,33 +105,33 @@ private: Module *TheModule = nullptr; /// Declaration for ObjC runtime function objc_autoreleaseReturnValue. - Constant *AutoreleaseRV = nullptr; + Function *AutoreleaseRV = nullptr; /// Declaration for ObjC runtime function objc_release. - Constant *Release = nullptr; + Function *Release = nullptr; /// Declaration for ObjC runtime function objc_retain. - Constant *Retain = nullptr; + Function *Retain = nullptr; /// Declaration for ObjC runtime function objc_retainBlock. - Constant *RetainBlock = nullptr; + Function *RetainBlock = nullptr; /// Declaration for ObjC runtime function objc_autorelease. - Constant *Autorelease = nullptr; + Function *Autorelease = nullptr; /// Declaration for objc_storeStrong(). - Constant *StoreStrong = nullptr; + Function *StoreStrong = nullptr; /// Declaration for objc_retainAutoreleasedReturnValue(). - Constant *RetainRV = nullptr; + Function *RetainRV = nullptr; /// Declaration for objc_retainAutorelease(). - Constant *RetainAutorelease = nullptr; + Function *RetainAutorelease = nullptr; /// Declaration for objc_retainAutoreleaseReturnValue(). 
- Constant *RetainAutoreleaseRV = nullptr; + Function *RetainAutoreleaseRV = nullptr; - Constant *getIntrinsicEntryPoint(Constant *&Decl, Intrinsic::ID IntID) { + Function *getIntrinsicEntryPoint(Function *&Decl, Intrinsic::ID IntID) { if (Decl) return Decl; diff --git a/lib/Transforms/ObjCARC/BlotMapVector.h b/lib/Transforms/ObjCARC/BlotMapVector.h index 9ade14c1177a..2fa07cfb32c0 100644 --- a/lib/Transforms/ObjCARC/BlotMapVector.h +++ b/lib/Transforms/ObjCARC/BlotMapVector.h @@ -1,9 +1,8 @@ //===- BlotMapVector.h - A MapVector with the blot operation ----*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/ObjCARC/DependencyAnalysis.cpp b/lib/Transforms/ObjCARC/DependencyAnalysis.cpp index 4bd5fd1acd4c..e8f8fb6f3a7c 100644 --- a/lib/Transforms/ObjCARC/DependencyAnalysis.cpp +++ b/lib/Transforms/ObjCARC/DependencyAnalysis.cpp @@ -1,9 +1,8 @@ //===- DependencyAnalysis.cpp - ObjC ARC Optimization ---------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// \file diff --git a/lib/Transforms/ObjCARC/DependencyAnalysis.h b/lib/Transforms/ObjCARC/DependencyAnalysis.h index 0f13b02c806f..ed89c8c8fc89 100644 --- a/lib/Transforms/ObjCARC/DependencyAnalysis.h +++ b/lib/Transforms/ObjCARC/DependencyAnalysis.h @@ -1,9 +1,8 @@ //===- DependencyAnalysis.h - ObjC ARC Optimization ---*- C++ -*-----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// \file diff --git a/lib/Transforms/ObjCARC/ObjCARC.cpp b/lib/Transforms/ObjCARC/ObjCARC.cpp index c30aaebd0f4d..f4da51650a7d 100644 --- a/lib/Transforms/ObjCARC/ObjCARC.cpp +++ b/lib/Transforms/ObjCARC/ObjCARC.cpp @@ -1,9 +1,8 @@ //===-- ObjCARC.cpp -------------------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/ObjCARC/ObjCARC.h b/lib/Transforms/ObjCARC/ObjCARC.h index 751c8f30e814..d465630800b9 100644 --- a/lib/Transforms/ObjCARC/ObjCARC.h +++ b/lib/Transforms/ObjCARC/ObjCARC.h @@ -1,9 +1,8 @@ //===- ObjCARC.h - ObjC ARC Optimization --------------*- C++ -*-----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// \file diff --git a/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp b/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp index 8d3ef8fde534..b341dd807508 100644 --- a/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp @@ -1,9 +1,8 @@ //===- ObjCARCAPElim.cpp - ObjC ARC Optimization --------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// \file diff --git a/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/lib/Transforms/ObjCARC/ObjCARCContract.cpp index abe2871c0b8f..36aa513ec554 100644 --- a/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -1,9 +1,8 @@ //===- ObjCARCContract.cpp - ObjC ARC Optimization ------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// \file @@ -46,6 +45,10 @@ using namespace llvm::objcarc; STATISTIC(NumPeeps, "Number of calls peephole-optimized"); STATISTIC(NumStoreStrongs, "Number objc_storeStrong calls formed"); +static cl::opt<unsigned> MaxBBSize("arc-contract-max-bb-size", cl::Hidden, + cl::desc("Maximum basic block size to discover the dominance relation of " + "two instructions in the same basic block"), cl::init(65535)); + //===----------------------------------------------------------------------===// // Declarations //===----------------------------------------------------------------------===// @@ -140,7 +143,7 @@ bool ObjCARCContract::optimizeRetainCall(Function &F, Instruction *Retain) { // We do not have to worry about tail calls/does not throw since // retain/retainRV have the same properties. 
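The arc-contract-max-bb-size threshold added above is a standard hidden cl::opt: a compile-time default that can still be overridden on the opt command line (for example -arc-contract-max-bb-size=1024). A hypothetical option of the same shape, with illustrative names that are not part of LLVM, would look like:

#include "llvm/Support/CommandLine.h"

using namespace llvm;

// Hidden knob with a default; reads as a plain unsigned wherever it is used.
static cl::opt<unsigned> MyMaxBBSize(
    "my-pass-max-bb-size", cl::Hidden,
    cl::desc("Illustrative size cutoff (not a real LLVM flag)"),
    cl::init(65535));

static bool exceedsLimit(unsigned BBSize) { return BBSize > MyMaxBBSize; }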
- Constant *Decl = EP.get(ARCRuntimeEntryPointKind::RetainRV); + Function *Decl = EP.get(ARCRuntimeEntryPointKind::RetainRV); cast<CallInst>(Retain)->setCalledFunction(Decl); LLVM_DEBUG(dbgs() << "New: " << *Retain << "\n"); @@ -189,7 +192,7 @@ bool ObjCARCContract::contractAutorelease( " Retain: " << *Retain << "\n"); - Constant *Decl = EP.get(Class == ARCInstKind::AutoreleaseRV + Function *Decl = EP.get(Class == ARCInstKind::AutoreleaseRV ? ARCRuntimeEntryPointKind::RetainAutoreleaseRV : ARCRuntimeEntryPointKind::RetainAutorelease); Retain->setCalledFunction(Decl); @@ -314,8 +317,8 @@ findRetainForStoreStrongContraction(Value *New, StoreInst *Store, /// Create a call instruction with the correct funclet token. Should be used /// instead of calling CallInst::Create directly. static CallInst * -createCallInst(Value *Func, ArrayRef<Value *> Args, const Twine &NameStr, - Instruction *InsertBefore, +createCallInst(FunctionType *FTy, Value *Func, ArrayRef<Value *> Args, + const Twine &NameStr, Instruction *InsertBefore, const DenseMap<BasicBlock *, ColorVector> &BlockColors) { SmallVector<OperandBundleDef, 1> OpBundles; if (!BlockColors.empty()) { @@ -326,7 +329,15 @@ createCallInst(Value *Func, ArrayRef<Value *> Args, const Twine &NameStr, OpBundles.emplace_back("funclet", EHPad); } - return CallInst::Create(Func, Args, OpBundles, NameStr, InsertBefore); + return CallInst::Create(FTy, Func, Args, OpBundles, NameStr, InsertBefore); +} + +static CallInst * +createCallInst(FunctionCallee Func, ArrayRef<Value *> Args, const Twine &NameStr, + Instruction *InsertBefore, + const DenseMap<BasicBlock *, ColorVector> &BlockColors) { + return createCallInst(Func.getFunctionType(), Func.getCallee(), Args, NameStr, + InsertBefore, BlockColors); } /// Attempt to merge an objc_release with a store, load, and objc_retain to form @@ -409,7 +420,7 @@ void ObjCARCContract::tryToContractReleaseIntoStoreStrong( Args[0] = new BitCastInst(Args[0], I8XX, "", Store); if (Args[1]->getType() != I8X) Args[1] = new BitCastInst(Args[1], I8X, "", Store); - Constant *Decl = EP.get(ARCRuntimeEntryPointKind::StoreStrong); + Function *Decl = EP.get(ARCRuntimeEntryPointKind::StoreStrong); CallInst *StoreStrong = createCallInst(Decl, Args, "", Store, BlockColors); StoreStrong->setDoesNotThrow(); StoreStrong->setDebugLoc(Store->getDebugLoc()); @@ -432,102 +443,100 @@ void ObjCARCContract::tryToContractReleaseIntoStoreStrong( } bool ObjCARCContract::tryToPeepholeInstruction( - Function &F, Instruction *Inst, inst_iterator &Iter, - SmallPtrSetImpl<Instruction *> &DependingInsts, - SmallPtrSetImpl<const BasicBlock *> &Visited, - bool &TailOkForStoreStrongs, - const DenseMap<BasicBlock *, ColorVector> &BlockColors) { - // Only these library routines return their argument. In particular, - // objc_retainBlock does not necessarily return its argument. + Function &F, Instruction *Inst, inst_iterator &Iter, + SmallPtrSetImpl<Instruction *> &DependingInsts, + SmallPtrSetImpl<const BasicBlock *> &Visited, bool &TailOkForStoreStrongs, + const DenseMap<BasicBlock *, ColorVector> &BlockColors) { + // Only these library routines return their argument. In particular, + // objc_retainBlock does not necessarily return its argument. 
ARCInstKind Class = GetBasicARCInstKind(Inst); - switch (Class) { - case ARCInstKind::FusedRetainAutorelease: - case ARCInstKind::FusedRetainAutoreleaseRV: + switch (Class) { + case ARCInstKind::FusedRetainAutorelease: + case ARCInstKind::FusedRetainAutoreleaseRV: + return false; + case ARCInstKind::Autorelease: + case ARCInstKind::AutoreleaseRV: + return contractAutorelease(F, Inst, Class, DependingInsts, Visited); + case ARCInstKind::Retain: + // Attempt to convert retains to retainrvs if they are next to function + // calls. + if (!optimizeRetainCall(F, Inst)) return false; - case ARCInstKind::Autorelease: - case ARCInstKind::AutoreleaseRV: - return contractAutorelease(F, Inst, Class, DependingInsts, Visited); - case ARCInstKind::Retain: - // Attempt to convert retains to retainrvs if they are next to function - // calls. - if (!optimizeRetainCall(F, Inst)) - return false; - // If we succeed in our optimization, fall through. - LLVM_FALLTHROUGH; - case ARCInstKind::RetainRV: - case ARCInstKind::ClaimRV: { - // If we're compiling for a target which needs a special inline-asm - // marker to do the return value optimization, insert it now. - if (!RVInstMarker) - return false; - BasicBlock::iterator BBI = Inst->getIterator(); - BasicBlock *InstParent = Inst->getParent(); - - // Step up to see if the call immediately precedes the RV call. - // If it's an invoke, we have to cross a block boundary. And we have - // to carefully dodge no-op instructions. - do { - if (BBI == InstParent->begin()) { - BasicBlock *Pred = InstParent->getSinglePredecessor(); - if (!Pred) - goto decline_rv_optimization; - BBI = Pred->getTerminator()->getIterator(); - break; - } - --BBI; - } while (IsNoopInstruction(&*BBI)); - - if (&*BBI == GetArgRCIdentityRoot(Inst)) { - LLVM_DEBUG(dbgs() << "Adding inline asm marker for the return value " - "optimization.\n"); - Changed = true; - InlineAsm *IA = InlineAsm::get( - FunctionType::get(Type::getVoidTy(Inst->getContext()), - /*isVarArg=*/false), - RVInstMarker->getString(), - /*Constraints=*/"", /*hasSideEffects=*/true); - - createCallInst(IA, None, "", Inst, BlockColors); - } - decline_rv_optimization: + // If we succeed in our optimization, fall through. + LLVM_FALLTHROUGH; + case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: { + // If we're compiling for a target which needs a special inline-asm + // marker to do the return value optimization, insert it now. + if (!RVInstMarker) return false; - } - case ARCInstKind::InitWeak: { - // objc_initWeak(p, null) => *p = null - CallInst *CI = cast<CallInst>(Inst); - if (IsNullOrUndef(CI->getArgOperand(1))) { - Value *Null = - ConstantPointerNull::get(cast<PointerType>(CI->getType())); - Changed = true; - new StoreInst(Null, CI->getArgOperand(0), CI); - - LLVM_DEBUG(dbgs() << "OBJCARCContract: Old = " << *CI << "\n" - << " New = " << *Null << "\n"); - - CI->replaceAllUsesWith(Null); - CI->eraseFromParent(); + BasicBlock::iterator BBI = Inst->getIterator(); + BasicBlock *InstParent = Inst->getParent(); + + // Step up to see if the call immediately precedes the RV call. + // If it's an invoke, we have to cross a block boundary. And we have + // to carefully dodge no-op instructions. 
+ do { + if (BBI == InstParent->begin()) { + BasicBlock *Pred = InstParent->getSinglePredecessor(); + if (!Pred) + goto decline_rv_optimization; + BBI = Pred->getTerminator()->getIterator(); + break; } - return true; + --BBI; + } while (IsNoopInstruction(&*BBI)); + + if (&*BBI == GetArgRCIdentityRoot(Inst)) { + LLVM_DEBUG(dbgs() << "Adding inline asm marker for the return value " + "optimization.\n"); + Changed = true; + InlineAsm *IA = + InlineAsm::get(FunctionType::get(Type::getVoidTy(Inst->getContext()), + /*isVarArg=*/false), + RVInstMarker->getString(), + /*Constraints=*/"", /*hasSideEffects=*/true); + + createCallInst(IA, None, "", Inst, BlockColors); } - case ARCInstKind::Release: - // Try to form an objc store strong from our release. If we fail, there is - // nothing further to do below, so continue. - tryToContractReleaseIntoStoreStrong(Inst, Iter, BlockColors); - return true; - case ARCInstKind::User: - // Be conservative if the function has any alloca instructions. - // Technically we only care about escaping alloca instructions, - // but this is sufficient to handle some interesting cases. - if (isa<AllocaInst>(Inst)) - TailOkForStoreStrongs = false; - return true; - case ARCInstKind::IntrinsicUser: - // Remove calls to @llvm.objc.clang.arc.use(...). - Inst->eraseFromParent(); - return true; - default: - return true; + decline_rv_optimization: + return false; + } + case ARCInstKind::InitWeak: { + // objc_initWeak(p, null) => *p = null + CallInst *CI = cast<CallInst>(Inst); + if (IsNullOrUndef(CI->getArgOperand(1))) { + Value *Null = ConstantPointerNull::get(cast<PointerType>(CI->getType())); + Changed = true; + new StoreInst(Null, CI->getArgOperand(0), CI); + + LLVM_DEBUG(dbgs() << "OBJCARCContract: Old = " << *CI << "\n" + << " New = " << *Null << "\n"); + + CI->replaceAllUsesWith(Null); + CI->eraseFromParent(); } + return true; + } + case ARCInstKind::Release: + // Try to form an objc store strong from our release. If we fail, there is + // nothing further to do below, so continue. + tryToContractReleaseIntoStoreStrong(Inst, Iter, BlockColors); + return true; + case ARCInstKind::User: + // Be conservative if the function has any alloca instructions. + // Technically we only care about escaping alloca instructions, + // but this is sufficient to handle some interesting cases. + if (isa<AllocaInst>(Inst)) + TailOkForStoreStrongs = false; + return true; + case ARCInstKind::IntrinsicUser: + // Remove calls to @llvm.objc.clang.arc.use(...). + Inst->eraseFromParent(); + return true; + default: + return true; + } } //===----------------------------------------------------------------------===// @@ -568,6 +577,24 @@ bool ObjCARCContract::runOnFunction(Function &F) { // reduces register pressure. SmallPtrSet<Instruction *, 4> DependingInstructions; SmallPtrSet<const BasicBlock *, 4> Visited; + + // Cache the basic block size. + DenseMap<const BasicBlock *, unsigned> BBSizeMap; + + // A lambda that lazily computes the size of a basic block and determines + // whether the size exceeds MaxBBSize. + auto IsLargeBB = [&](const BasicBlock *BB) { + unsigned BBSize; + auto I = BBSizeMap.find(BB); + + if (I != BBSizeMap.end()) + BBSize = I->second; + else + BBSize = BBSizeMap[BB] = BB->size(); + + return BBSize > MaxBBSize; + }; + for (inst_iterator I = inst_begin(&F), E = inst_end(&F); I != E;) { Instruction *Inst = &*I++; @@ -585,7 +612,7 @@ bool ObjCARCContract::runOnFunction(Function &F) { // and such; to do the replacement, the argument must have type i8*. 
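The IsLargeBB lambda introduced above is a memoized size query: BasicBlock::size() walks the instruction list, so the result is cached per block and compared against the arc-contract-max-bb-size cutoff, and the hunk that follows threads the check into ReplaceArgUses so the dominance query is skipped inside oversized blocks. The same idea in isolation, as a sketch with illustrative names (BBSizeCache, isLarge):

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/BasicBlock.h"

using namespace llvm;

class BBSizeCache {
  DenseMap<const BasicBlock *, unsigned> Sizes;
  unsigned Threshold;

public:
  explicit BBSizeCache(unsigned Threshold) : Threshold(Threshold) {}

  // Compute BB->size() at most once per block; later queries are O(1).
  bool isLarge(const BasicBlock *BB) {
    auto It = Sizes.find(BB);
    unsigned Size =
        It != Sizes.end() ? It->second : (Sizes[BB] = BB->size());
    return Size > Threshold;
  }
};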
// Function for replacing uses of Arg dominated by Inst. - auto ReplaceArgUses = [Inst, this](Value *Arg) { + auto ReplaceArgUses = [Inst, IsLargeBB, this](Value *Arg) { // If we're compiling bugpointed code, don't get in trouble. if (!isa<Instruction>(Arg) && !isa<Argument>(Arg)) return; @@ -597,6 +624,17 @@ bool ObjCARCContract::runOnFunction(Function &F) { Use &U = *UI++; unsigned OperandNo = U.getOperandNo(); + // Don't replace the uses if Inst and the user belong to the same basic + // block and the size of the basic block is large. We don't want to call + // DominatorTree::dominate in that case. We can remove this check if we + // can use OrderedBasicBlock to compute the dominance relation between + // two instructions, but that's not currently possible since it doesn't + // recompute the instruction ordering when new instructions are inserted + // to the basic block. + if (Inst->getParent() == cast<Instruction>(U.getUser())->getParent() && + IsLargeBB(Inst->getParent())) + continue; + // If the call's return value dominates a use of the call's argument // value, rewrite the use to use the return value. We check for // reachability here because an unreachable call is considered to @@ -737,15 +775,8 @@ bool ObjCARCContract::doInitialization(Module &M) { EP.init(&M); // Initialize RVInstMarker. - RVInstMarker = nullptr; - if (NamedMDNode *NMD = - M.getNamedMetadata("clang.arc.retainAutoreleasedReturnValueMarker")) - if (NMD->getNumOperands() == 1) { - const MDNode *N = NMD->getOperand(0); - if (N->getNumOperands() == 1) - if (const MDString *S = dyn_cast<MDString>(N->getOperand(0))) - RVInstMarker = S; - } + const char *MarkerKey = "clang.arc.retainAutoreleasedReturnValueMarker"; + RVInstMarker = dyn_cast_or_null<MDString>(M.getModuleFlag(MarkerKey)); return false; } diff --git a/lib/Transforms/ObjCARC/ObjCARCExpand.cpp b/lib/Transforms/ObjCARC/ObjCARCExpand.cpp index 6a345ef56e1b..04e98d8f5577 100644 --- a/lib/Transforms/ObjCARC/ObjCARCExpand.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCExpand.cpp @@ -1,9 +1,8 @@ //===- ObjCARCExpand.cpp - ObjC ARC Optimization --------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// \file diff --git a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index 9a02174556fc..6653ff0bb91a 100644 --- a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -1,9 +1,8 @@ //===- ObjCARCOpts.cpp - ObjC ARC Optimization ----------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -74,6 +73,11 @@ using namespace llvm::objcarc; #define DEBUG_TYPE "objc-arc-opts" +static cl::opt<unsigned> MaxPtrStates("arc-opt-max-ptr-states", + cl::Hidden, + cl::desc("Maximum number of ptr states the optimizer keeps track of"), + cl::init(4095)); + /// \defgroup ARCUtilities Utility declarations/definitions specific to ARC. /// @{ @@ -220,6 +224,10 @@ namespace { return !PerPtrTopDown.empty(); } + unsigned top_down_ptr_list_size() const { + return std::distance(top_down_ptr_begin(), top_down_ptr_end()); + } + using bottom_up_ptr_iterator = decltype(PerPtrBottomUp)::iterator; using const_bottom_up_ptr_iterator = decltype(PerPtrBottomUp)::const_iterator; @@ -238,6 +246,10 @@ namespace { return !PerPtrBottomUp.empty(); } + unsigned bottom_up_ptr_list_size() const { + return std::distance(bottom_up_ptr_begin(), bottom_up_ptr_end()); + } + /// Mark this block as being an entry block, which has one path from the /// entry by definition. void SetAsEntry() { TopDownPathCount = 1; } @@ -481,6 +493,10 @@ namespace { /// A flag indicating whether this optimization pass should run. bool Run; + /// A flag indicating whether the optimization that removes or moves + /// retain/release pairs should be performed. + bool DisableRetainReleasePairing = false; + /// Flags which determine whether each of the interesting runtime functions /// is in fact used in the current function. unsigned UsedInThisFunction; @@ -642,7 +658,7 @@ ObjCARCOpt::OptimizeRetainRVCall(Function &F, Instruction *RetainRV) { "Old = " << *RetainRV << "\n"); - Constant *NewDecl = EP.get(ARCRuntimeEntryPointKind::Retain); + Function *NewDecl = EP.get(ARCRuntimeEntryPointKind::Retain); cast<CallInst>(RetainRV)->setCalledFunction(NewDecl); LLVM_DEBUG(dbgs() << "New = " << *RetainRV << "\n"); @@ -691,7 +707,7 @@ void ObjCARCOpt::OptimizeAutoreleaseRVCall(Function &F, << *AutoreleaseRV << "\n"); CallInst *AutoreleaseRVCI = cast<CallInst>(AutoreleaseRV); - Constant *NewDecl = EP.get(ARCRuntimeEntryPointKind::Autorelease); + Function *NewDecl = EP.get(ARCRuntimeEntryPointKind::Autorelease); AutoreleaseRVCI->setCalledFunction(NewDecl); AutoreleaseRVCI->setTailCall(false); // Never tail call objc_autorelease. Class = ARCInstKind::Autorelease; @@ -744,6 +760,19 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { LLVM_DEBUG(dbgs() << "Visiting: Class: " << Class << "; " << *Inst << "\n"); + // Some of the ARC calls can be deleted if their arguments are global + // variables that are inert in ARC. + if (IsNoopOnGlobal(Class)) { + Value *Opnd = Inst->getOperand(0); + if (auto *GV = dyn_cast<GlobalVariable>(Opnd->stripPointerCasts())) + if (GV->hasAttribute("objc_arc_inert")) { + if (!Inst->getType()->isVoidTy()) + Inst->replaceAllUsesWith(Opnd); + Inst->eraseFromParent(); + continue; + } + } + switch (Class) { default: break; @@ -830,7 +859,7 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { // Create the declaration lazily. LLVMContext &C = Inst->getContext(); - Constant *Decl = EP.get(ARCRuntimeEntryPointKind::Release); + Function *Decl = EP.get(ARCRuntimeEntryPointKind::Release); CallInst *NewCall = CallInst::Create(Decl, Call->getArgOperand(0), "", Call); NewCall->setMetadata(MDKindCache.get(ARCMDKindID::ImpreciseRelease), @@ -849,7 +878,7 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { // For functions which can never be passed stack arguments, add // a tail keyword. 
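The IsNoopOnGlobal early-out added above keys off a string attribute that the frontend places on inert globals, looked up through any constant pointer casts. The test in isolation, as a sketch (isInertARCOperand is an illustrative name, not part of the patch):

#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// True when the ARC call's first operand is (a cast of) a global variable
// that was marked inert for ARC purposes.
static bool isInertARCOperand(const Instruction *Inst) {
  const Value *Opnd = Inst->getOperand(0);
  if (const auto *GV = dyn_cast<GlobalVariable>(Opnd->stripPointerCasts()))
    return GV->hasAttribute("objc_arc_inert");
  return false;
}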
- if (IsAlwaysTail(Class)) { + if (IsAlwaysTail(Class) && !cast<CallInst>(Inst)->isNoTailCall()) { Changed = true; LLVM_DEBUG( dbgs() << "Adding tail keyword to function since it can never be " @@ -1273,6 +1302,13 @@ bool ObjCARCOpt::VisitBottomUp(BasicBlock *BB, LLVM_DEBUG(dbgs() << " Visiting " << *Inst << "\n"); NestingDetected |= VisitInstructionBottomUp(Inst, BB, Retains, MyStates); + + // Bail out if the number of pointers being tracked becomes too large so + // that this pass can complete in a reasonable amount of time. + if (MyStates.bottom_up_ptr_list_size() > MaxPtrStates) { + DisableRetainReleasePairing = true; + return false; + } } // If there's a predecessor with an invoke, visit the invoke as if it were @@ -1395,6 +1431,13 @@ ObjCARCOpt::VisitTopDown(BasicBlock *BB, LLVM_DEBUG(dbgs() << " Visiting " << Inst << "\n"); NestingDetected |= VisitInstructionTopDown(&Inst, Releases, MyStates); + + // Bail out if the number of pointers being tracked becomes too large so + // that this pass can complete in a reasonable amount of time. + if (MyStates.top_down_ptr_list_size() > MaxPtrStates) { + DisableRetainReleasePairing = true; + return false; + } } LLVM_DEBUG(dbgs() << "\nState Before Checking for CFG Hazards:\n" @@ -1501,13 +1544,19 @@ bool ObjCARCOpt::Visit(Function &F, // Use reverse-postorder on the reverse CFG for bottom-up. bool BottomUpNestingDetected = false; - for (BasicBlock *BB : llvm::reverse(ReverseCFGPostOrder)) + for (BasicBlock *BB : llvm::reverse(ReverseCFGPostOrder)) { BottomUpNestingDetected |= VisitBottomUp(BB, BBStates, Retains); + if (DisableRetainReleasePairing) + return false; + } // Use reverse-postorder for top-down. bool TopDownNestingDetected = false; - for (BasicBlock *BB : llvm::reverse(PostOrder)) + for (BasicBlock *BB : llvm::reverse(PostOrder)) { TopDownNestingDetected |= VisitTopDown(BB, BBStates, Releases); + if (DisableRetainReleasePairing) + return false; + } return TopDownNestingDetected && BottomUpNestingDetected; } @@ -1528,7 +1577,7 @@ void ObjCARCOpt::MoveCalls(Value *Arg, RRInfo &RetainsToMove, for (Instruction *InsertPt : ReleasesToMove.ReverseInsertPts) { Value *MyArg = ArgTy == ParamTy ? Arg : new BitCastInst(Arg, ParamTy, "", InsertPt); - Constant *Decl = EP.get(ARCRuntimeEntryPointKind::Retain); + Function *Decl = EP.get(ARCRuntimeEntryPointKind::Retain); CallInst *Call = CallInst::Create(Decl, MyArg, "", InsertPt); Call->setDoesNotThrow(); Call->setTailCall(); @@ -1541,7 +1590,7 @@ void ObjCARCOpt::MoveCalls(Value *Arg, RRInfo &RetainsToMove, for (Instruction *InsertPt : RetainsToMove.ReverseInsertPts) { Value *MyArg = ArgTy == ParamTy ? Arg : new BitCastInst(Arg, ParamTy, "", InsertPt); - Constant *Decl = EP.get(ARCRuntimeEntryPointKind::Release); + Function *Decl = EP.get(ARCRuntimeEntryPointKind::Release); CallInst *Call = CallInst::Create(Decl, MyArg, "", InsertPt); // Attach a clang.imprecise_release metadata tag, if appropriate. if (MDNode *M = ReleasesToMove.ReleaseMetadata) @@ -1877,7 +1926,7 @@ void ObjCARCOpt::OptimizeWeakCalls(Function &F) { Changed = true; // If the load has a builtin retain, insert a plain retain for it. if (Class == ARCInstKind::LoadWeakRetained) { - Constant *Decl = EP.get(ARCRuntimeEntryPointKind::Retain); + Function *Decl = EP.get(ARCRuntimeEntryPointKind::Retain); CallInst *CI = CallInst::Create(Decl, EarlierCall, "", Call); CI->setTailCall(); } @@ -1906,7 +1955,7 @@ void ObjCARCOpt::OptimizeWeakCalls(Function &F) { Changed = true; // If the load has a builtin retain, insert a plain retain for it. 
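The checks added to VisitBottomUp and VisitTopDown above are a simple complexity guard: once the number of tracked pointer states in a block exceeds arc-opt-max-ptr-states, the pass sets DisableRetainReleasePairing and unwinds instead of spending unbounded time on pairing. Schematically, under illustrative names (Analysis, visitBlock) that are not LLVM code:

// Sketch of a size-bounded visit: the sticky flag tells the caller to stop.
struct Analysis {
  unsigned MaxTrackedStates; // mirrors the -arc-opt-max-ptr-states knob
  bool DisablePairing = false;

  template <typename StateMap> bool visitBlock(const StateMap &States) {
    // ... per-instruction work would go here ...
    if (States.size() > MaxTrackedStates) {
      DisablePairing = true; // checked after each block and by the caller
      return false;
    }
    return true;
  }
};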
if (Class == ARCInstKind::LoadWeakRetained) { - Constant *Decl = EP.get(ARCRuntimeEntryPointKind::Retain); + Function *Decl = EP.get(ARCRuntimeEntryPointKind::Retain); CallInst *CI = CallInst::Create(Decl, EarlierCall, "", Call); CI->setTailCall(); } @@ -2003,6 +2052,9 @@ bool ObjCARCOpt::OptimizeSequences(Function &F) { // Analyze the CFG of the function, and all instructions. bool NestingDetected = Visit(F, BBStates, Retains, Releases); + if (DisableRetainReleasePairing) + return false; + // Transform. bool AnyPairsCompletelyEliminated = PerformCodePlacement(BBStates, Retains, Releases, diff --git a/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp b/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp index 3004fffb9745..c6138edba95a 100644 --- a/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp +++ b/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp @@ -1,9 +1,8 @@ //===- ProvenanceAnalysis.cpp - ObjC ARC Optimization ---------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/ObjCARC/ProvenanceAnalysis.h b/lib/Transforms/ObjCARC/ProvenanceAnalysis.h index 1276f564a022..8fd842fd42d6 100644 --- a/lib/Transforms/ObjCARC/ProvenanceAnalysis.h +++ b/lib/Transforms/ObjCARC/ProvenanceAnalysis.h @@ -1,9 +1,8 @@ //===- ProvenanceAnalysis.h - ObjC ARC Optimization -------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp b/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp index 870a5f600fd8..b768f7973b87 100644 --- a/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp +++ b/lib/Transforms/ObjCARC/ProvenanceAnalysisEvaluator.cpp @@ -1,9 +1,8 @@ //===- ProvenanceAnalysisEvaluator.cpp - ObjC ARC Optimization ------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/ObjCARC/PtrState.cpp b/lib/Transforms/ObjCARC/PtrState.cpp index 8a7b6a74fae2..3243481dee0d 100644 --- a/lib/Transforms/ObjCARC/PtrState.cpp +++ b/lib/Transforms/ObjCARC/PtrState.cpp @@ -1,9 +1,8 @@ //===- PtrState.cpp -------------------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. 
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/ObjCARC/PtrState.h b/lib/Transforms/ObjCARC/PtrState.h index f5b9b853d8e3..66614c06cb79 100644 --- a/lib/Transforms/ObjCARC/PtrState.h +++ b/lib/Transforms/ObjCARC/PtrState.h @@ -1,9 +1,8 @@ //===- PtrState.h - ARC State for a Ptr -------------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/ADCE.cpp b/lib/Transforms/Scalar/ADCE.cpp index b0602d96798c..7f7460c5746a 100644 --- a/lib/Transforms/Scalar/ADCE.cpp +++ b/lib/Transforms/Scalar/ADCE.cpp @@ -1,9 +1,8 @@ //===- ADCE.cpp - Code to perform dead code elimination -------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -20,9 +19,11 @@ #include "llvm/ADT/GraphTraits.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/IteratedDominanceFrontier.h" #include "llvm/Analysis/PostDominators.h" @@ -30,7 +31,6 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -136,7 +136,7 @@ class AggressiveDeadCodeElimination { SmallPtrSet<const Metadata *, 32> AliveScopes; /// Set of blocks with not known to have live terminators. - SmallPtrSet<BasicBlock *, 16> BlocksWithDeadTerminators; + SmallSetVector<BasicBlock *, 16> BlocksWithDeadTerminators; /// The set of blocks which we have determined whose control /// dependence sources must be live and which have not had @@ -390,7 +390,7 @@ void AggressiveDeadCodeElimination::markLive(Instruction *I) { // Mark the containing block live auto &BBInfo = *Info.Block; if (BBInfo.Terminator == I) { - BlocksWithDeadTerminators.erase(BBInfo.BB); + BlocksWithDeadTerminators.remove(BBInfo.BB); // For live terminators, mark destination blocks // live to preserve this control flow edges. if (!BBInfo.UnconditionalBranch) @@ -479,10 +479,14 @@ void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() { // which currently have dead terminators that are control // dependence sources of a block which is in NewLiveBlocks. 
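The container change above swaps SmallPtrSet for SmallSetVector so that BlocksWithDeadTerminators iterates in insertion order rather than pointer order, which keeps ADCE's behaviour deterministic; the hunk that follows copies it into a temporary SmallPtrSet because setLiveInBlocks still wants a set. A small standalone sketch of the container's behaviour (the int pointers are placeholders):

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"

using namespace llvm;

void setVectorSketch(int *A, int *B) {
  SmallSetVector<int *, 16> Ordered;
  Ordered.insert(B);
  Ordered.insert(A);
  Ordered.insert(B); // duplicate, ignored
  Ordered.remove(A); // note: remove(), not erase()
  // Iteration now visits the survivors in insertion order: just B.
  // APIs that expect a set-like type still need a copy:
  SmallPtrSet<int *, 16> AsSet(Ordered.begin(), Ordered.end());
  (void)AsSet;
}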
+ const SmallPtrSet<BasicBlock *, 16> BWDT{ + BlocksWithDeadTerminators.begin(), + BlocksWithDeadTerminators.end() + }; SmallVector<BasicBlock *, 32> IDFBlocks; ReverseIDFCalculator IDFs(PDT); IDFs.setDefiningBlocks(NewLiveBlocks); - IDFs.setLiveInBlocks(BlocksWithDeadTerminators); + IDFs.setLiveInBlocks(BWDT); IDFs.calculate(IDFBlocks); NewLiveBlocks.clear(); diff --git a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index 0830ff5dd042..de9a62e88c27 100644 --- a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -1,10 +1,9 @@ //===----------------------- AlignmentFromAssumptions.cpp -----------------===// // Set Load/Store Alignments From Assumptions // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/BDCE.cpp b/lib/Transforms/Scalar/BDCE.cpp index d3c9b9a270aa..9bd387c33e80 100644 --- a/lib/Transforms/Scalar/BDCE.cpp +++ b/lib/Transforms/Scalar/BDCE.cpp @@ -1,9 +1,8 @@ //===---- BDCE.cpp - Bit-tracking dead code elimination -------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -42,14 +41,17 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { "Trivializing a non-integer value?"); // Initialize the worklist with eligible direct users. + SmallPtrSet<Instruction *, 16> Visited; SmallVector<Instruction *, 16> WorkList; for (User *JU : I->users()) { // If all bits of a user are demanded, then we know that nothing below that // in the def-use chain needs to be changed. auto *J = dyn_cast<Instruction>(JU); if (J && J->getType()->isIntOrIntVectorTy() && - !DB.getDemandedBits(J).isAllOnesValue()) + !DB.getDemandedBits(J).isAllOnesValue()) { + Visited.insert(J); WorkList.push_back(J); + } // Note that we need to check for non-int types above before asking for // demanded bits. Normally, the only way to reach an instruction with an @@ -62,7 +64,6 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { } // DFS through subsequent users while tracking visits to avoid cycles. - SmallPtrSet<Instruction *, 16> Visited; while (!WorkList.empty()) { Instruction *J = WorkList.pop_back_val(); @@ -73,13 +74,11 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { // 1. llvm.assume demands its operand, so trivializing can't change it. // 2. range metadata only applies to memory accesses which demand all bits. - Visited.insert(J); - for (User *KU : J->users()) { // If all bits of a user are demanded, then we know that nothing below // that in the def-use chain needs to be changed. 
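The Visited rework above (finished by the hunk that follows, which turns the inner check into Visited.insert(K).second) marks instructions as seen at the moment they are pushed, so nothing can land on the worklist twice. The same worklist idiom in isolation, under the illustrative name forEachTransitiveUser:

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// Visit the transitive instruction users of Root exactly once, depth-first.
template <typename Fn>
void forEachTransitiveUser(Instruction *Root, Fn Visit) {
  SmallPtrSet<Instruction *, 16> Seen;
  SmallVector<Instruction *, 16> WorkList;
  auto Push = [&](User *U) {
    if (auto *I = dyn_cast<Instruction>(U))
      if (Seen.insert(I).second) // mark at push time, not at pop time
        WorkList.push_back(I);
  };
  for (User *U : Root->users())
    Push(U);
  while (!WorkList.empty()) {
    Instruction *I = WorkList.pop_back_val();
    Visit(I);
    for (User *U : I->users())
      Push(U);
  }
}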
auto *K = dyn_cast<Instruction>(KU); - if (K && !Visited.count(K) && K->getType()->isIntOrIntVectorTy() && + if (K && Visited.insert(K).second && K->getType()->isIntOrIntVectorTy() && !DB.getDemandedBits(K).isAllOnesValue()) WorkList.push_back(K); } diff --git a/lib/Transforms/Scalar/CallSiteSplitting.cpp b/lib/Transforms/Scalar/CallSiteSplitting.cpp index a806d6faed60..3519b000a33f 100644 --- a/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -1,9 +1,8 @@ //===- CallSiteSplitting.cpp ----------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -184,6 +183,9 @@ static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) { } static bool canSplitCallSite(CallSite CS, TargetTransformInfo &TTI) { + if (CS.isConvergent() || CS.cannotDuplicate()) + return false; + // FIXME: As of now we handle only CallInst. InvokeInst could be handled // without too much effort. Instruction *Instr = CS.getInstruction(); @@ -367,7 +369,7 @@ static void splitCallSite( assert(Splits.size() == 2 && "Expected exactly 2 splits!"); for (unsigned i = 0; i < Splits.size(); i++) { Splits[i]->getTerminator()->eraseFromParent(); - DTU.deleteEdge(Splits[i], TailBB); + DTU.applyUpdatesPermissive({{DominatorTree::Delete, Splits[i], TailBB}}); } // Erase the tail block once done with musttail patching diff --git a/lib/Transforms/Scalar/ConstantHoisting.cpp b/lib/Transforms/Scalar/ConstantHoisting.cpp index beac0d967a98..98243a23f1ef 100644 --- a/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -1,9 +1,8 @@ //===- ConstantHoisting.cpp - Prepare code for expensive constants --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -42,6 +41,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" @@ -61,6 +61,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/SizeOpts.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -112,11 +113,10 @@ public: if (ConstHoistWithBlockFrequency) AU.addRequired<BlockFrequencyInfoWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); } - void releaseMemory() override { Impl.releaseMemory(); } - private: ConstantHoistingPass Impl; }; @@ -129,6 +129,7 @@ INITIALIZE_PASS_BEGIN(ConstantHoistingLegacyPass, "consthoist", "Constant Hoisting", false, false) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(ConstantHoistingLegacyPass, "consthoist", "Constant Hoisting", false, false) @@ -151,7 +152,8 @@ bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) { ConstHoistWithBlockFrequency ? &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI() : nullptr, - Fn.getEntryBlock()); + Fn.getEntryBlock(), + &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI()); if (MadeChange) { LLVM_DEBUG(dbgs() << "********** Function after Constant Hoisting: " @@ -211,6 +213,9 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, // in the dominator tree from Entry to 'BB'. SmallPtrSet<BasicBlock *, 16> Candidates; for (auto BB : BBs) { + // Ignore unreachable basic blocks. + if (!DT.isReachableFromEntry(BB)) + continue; Path.clear(); // Walk up the dominator tree until Entry or another BB in BBs // is reached. Insert the nodes on the way to the Path. @@ -548,7 +553,9 @@ ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S, ConstCandVecType::iterator &MaxCostItr) { unsigned NumUses = 0; - if(!Entry->getParent()->optForSize() || std::distance(S,E) > 100) { + bool OptForSize = Entry->getParent()->hasOptSize() || + llvm::shouldOptimizeForSize(Entry->getParent(), PSI, BFI); + if (!OptForSize || std::distance(S,E) > 100) { for (auto ConstCand = S; ConstCand != E; ++ConstCand) { NumUses += ConstCand->Uses.size(); if (ConstCand->CumulativeCost > MaxCostItr->CumulativeCost) @@ -640,8 +647,8 @@ void ConstantHoistingPass::findBaseConstants(GlobalVariable *BaseGV) { ConstGEPInfoMap[BaseGV] : ConstIntInfoVec; // Sort the constants by value and type. This invalidates the mapping! - std::stable_sort(ConstCandVec.begin(), ConstCandVec.end(), - [](const ConstantCandidate &LHS, const ConstantCandidate &RHS) { + llvm::stable_sort(ConstCandVec, [](const ConstantCandidate &LHS, + const ConstantCandidate &RHS) { if (LHS.ConstInt->getType() != RHS.ConstInt->getType()) return LHS.ConstInt->getType()->getBitWidth() < RHS.ConstInt->getType()->getBitWidth(); @@ -824,7 +831,9 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) { BaseGV ? 
ConstGEPInfoMap[BaseGV] : ConstIntInfoVec; for (auto const &ConstInfo : ConstInfoVec) { SmallPtrSet<Instruction *, 8> IPSet = findConstantInsertionPoint(ConstInfo); - assert(!IPSet.empty() && "IPSet is empty"); + // We can have an empty set if the function contains unreachable blocks. + if (IPSet.empty()) + continue; unsigned UsesNum = 0; unsigned ReBasesNum = 0; @@ -917,13 +926,14 @@ void ConstantHoistingPass::deleteDeadCastInst() const { /// Optimize expensive integer constants in the given function. bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI, DominatorTree &DT, BlockFrequencyInfo *BFI, - BasicBlock &Entry) { + BasicBlock &Entry, ProfileSummaryInfo *PSI) { this->TTI = &TTI; this->DT = &DT; this->BFI = BFI; this->DL = &Fn.getParent()->getDataLayout(); this->Ctx = &Fn.getContext(); this->Entry = &Entry; + this->PSI = PSI; // Collect all constant candidates. collectConstantCandidates(Fn); @@ -948,6 +958,8 @@ bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI, // Cleanup dead instructions. deleteDeadCastInst(); + cleanup(); + return MadeChange; } @@ -958,7 +970,9 @@ PreservedAnalyses ConstantHoistingPass::run(Function &F, auto BFI = ConstHoistWithBlockFrequency ? &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; - if (!runImpl(F, TTI, DT, BFI, F.getEntryBlock())) + auto &MAM = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); + auto *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + if (!runImpl(F, TTI, DT, BFI, F.getEntryBlock(), PSI)) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/lib/Transforms/Scalar/ConstantProp.cpp b/lib/Transforms/Scalar/ConstantProp.cpp index 51032b0625f8..770321c740a0 100644 --- a/lib/Transforms/Scalar/ConstantProp.cpp +++ b/lib/Transforms/Scalar/ConstantProp.cpp @@ -1,9 +1,8 @@ //===- ConstantProp.cpp - Code to perform Simple Constant Propagation -----===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index d0105701c73f..89497177524f 100644 --- a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -1,9 +1,8 @@ //===- CorrelatedValuePropagation.cpp - Propagate CFG-derived info --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -16,6 +15,7 @@ #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" @@ -27,7 +27,6 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" @@ -64,8 +63,10 @@ STATISTIC(NumUDivs, "Number of udivs whose width was decreased"); STATISTIC(NumAShrs, "Number of ashr converted to lshr"); STATISTIC(NumSRems, "Number of srem converted to urem"); STATISTIC(NumOverflows, "Number of overflow checks removed"); +STATISTIC(NumSaturating, + "Number of saturating arithmetics converted to normal arithmetics"); -static cl::opt<bool> DontProcessAdds("cvp-dont-process-adds", cl::init(true)); +static cl::opt<bool> DontAddNoWrapFlags("cvp-dont-add-nowrap-flags", cl::init(false)); namespace { @@ -307,11 +308,11 @@ static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) { /// that cannot fire no matter what the incoming edge can safely be removed. If /// a case fires on every incoming edge then the entire switch can be removed /// and replaced with a branch to the case destination. -static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI, +static bool processSwitch(SwitchInst *I, LazyValueInfo *LVI, DominatorTree *DT) { DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy); - Value *Cond = SI->getCondition(); - BasicBlock *BB = SI->getParent(); + Value *Cond = I->getCondition(); + BasicBlock *BB = I->getParent(); // If the condition was defined in same block as the switch then LazyValueInfo // currently won't say anything useful about it, though in theory it could. @@ -328,67 +329,72 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI, for (auto *Succ : successors(BB)) SuccessorsCount[Succ]++; - for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) { - ConstantInt *Case = CI->getCaseValue(); - - // Check to see if the switch condition is equal to/not equal to the case - // value on every incoming edge, equal/not equal being the same each time. - LazyValueInfo::Tristate State = LazyValueInfo::Unknown; - for (pred_iterator PI = PB; PI != PE; ++PI) { - // Is the switch condition equal to the case value? - LazyValueInfo::Tristate Value = LVI->getPredicateOnEdge(CmpInst::ICMP_EQ, - Cond, Case, *PI, - BB, SI); - // Give up on this case if nothing is known. - if (Value == LazyValueInfo::Unknown) { - State = LazyValueInfo::Unknown; - break; + { // Scope for SwitchInstProfUpdateWrapper. It must not live during + // ConstantFoldTerminator() as the underlying SwitchInst can be changed. + SwitchInstProfUpdateWrapper SI(*I); + + for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) { + ConstantInt *Case = CI->getCaseValue(); + + // Check to see if the switch condition is equal to/not equal to the case + // value on every incoming edge, equal/not equal being the same each time. + LazyValueInfo::Tristate State = LazyValueInfo::Unknown; + for (pred_iterator PI = PB; PI != PE; ++PI) { + // Is the switch condition equal to the case value? 
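The rewrite of processSwitch above routes case removal through a SwitchInstProfUpdateWrapper, which keeps the branch_weights profile metadata in sync as cases disappear; per the comment in the hunk, the wrapper is scoped so it has been destroyed (and the weights written back) before ConstantFoldTerminator can touch the instruction. A minimal sketch following the same calls (dropCasesTo is an illustrative name):

#include "llvm/IR/Instructions.h"

using namespace llvm;

// Remove every case that jumps to DeadSucc, keeping !prof data consistent.
static void dropCasesTo(SwitchInst &Sw, BasicBlock *DeadSucc) {
  SwitchInstProfUpdateWrapper SI(Sw);
  for (auto CI = SI->case_begin(); CI != SI->case_end();) {
    if (CI->getCaseSuccessor() == DeadSucc) {
      CI = SI.removeCase(CI); // the wrapper drops the matching weight
      continue;
    }
    ++CI;
  }
} // the wrapper's destructor re-attaches the updated metadata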
+ LazyValueInfo::Tristate Value = LVI->getPredicateOnEdge(CmpInst::ICMP_EQ, + Cond, Case, *PI, + BB, SI); + // Give up on this case if nothing is known. + if (Value == LazyValueInfo::Unknown) { + State = LazyValueInfo::Unknown; + break; + } + + // If this was the first edge to be visited, record that all other edges + // need to give the same result. + if (PI == PB) { + State = Value; + continue; + } + + // If this case is known to fire for some edges and known not to fire for + // others then there is nothing we can do - give up. + if (Value != State) { + State = LazyValueInfo::Unknown; + break; + } } - // If this was the first edge to be visited, record that all other edges - // need to give the same result. - if (PI == PB) { - State = Value; + if (State == LazyValueInfo::False) { + // This case never fires - remove it. + BasicBlock *Succ = CI->getCaseSuccessor(); + Succ->removePredecessor(BB); + CI = SI.removeCase(CI); + CE = SI->case_end(); + + // The condition can be modified by removePredecessor's PHI simplification + // logic. + Cond = SI->getCondition(); + + ++NumDeadCases; + Changed = true; + if (--SuccessorsCount[Succ] == 0) + DTU.applyUpdatesPermissive({{DominatorTree::Delete, BB, Succ}}); continue; } - - // If this case is known to fire for some edges and known not to fire for - // others then there is nothing we can do - give up. - if (Value != State) { - State = LazyValueInfo::Unknown; + if (State == LazyValueInfo::True) { + // This case always fires. Arrange for the switch to be turned into an + // unconditional branch by replacing the switch condition with the case + // value. + SI->setCondition(Case); + NumDeadCases += SI->getNumCases(); + Changed = true; break; } - } - if (State == LazyValueInfo::False) { - // This case never fires - remove it. - BasicBlock *Succ = CI->getCaseSuccessor(); - Succ->removePredecessor(BB); - CI = SI->removeCase(CI); - CE = SI->case_end(); - - // The condition can be modified by removePredecessor's PHI simplification - // logic. - Cond = SI->getCondition(); - - ++NumDeadCases; - Changed = true; - if (--SuccessorsCount[Succ] == 0) - DTU.deleteEdge(BB, Succ); - continue; - } - if (State == LazyValueInfo::True) { - // This case always fires. Arrange for the switch to be turned into an - // unconditional branch by replacing the switch condition with the case - // value. - SI->setCondition(Case); - NumDeadCases += SI->getNumCases(); - Changed = true; - break; + // Increment the case iterator since we didn't delete it. + ++CI; } - - // Increment the case iterator since we didn't delete it. - ++CI; } if (Changed) @@ -399,56 +405,48 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI, return Changed; } -// See if we can prove that the given overflow intrinsic will not overflow. -static bool willNotOverflow(IntrinsicInst *II, LazyValueInfo *LVI) { - using OBO = OverflowingBinaryOperator; - auto NoWrap = [&] (Instruction::BinaryOps BinOp, unsigned NoWrapKind) { - Value *RHS = II->getOperand(1); - ConstantRange RRange = LVI->getConstantRange(RHS, II->getParent(), II); - ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion( - BinOp, RRange, NoWrapKind); - // As an optimization, do not compute LRange if we do not need it. 
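The new willNotOverflow above folds the old per-intrinsic switch into one query: BinaryOpIntrinsic already exposes the opcode and the no-wrap kind, so the proof is just a ConstantRange containment test. That test in isolation, as a sketch for the unsigned-add case (provesNoUnsignedAddWrap is an illustrative name):

#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Operator.h"

using namespace llvm;

// True if every value in LRange added to every value in RRange stays in
// range, i.e. the addition could safely be marked nuw.
static bool provesNoUnsignedAddWrap(const ConstantRange &LRange,
                                    const ConstantRange &RRange) {
  ConstantRange NoWrap = ConstantRange::makeGuaranteedNoWrapRegion(
      Instruction::Add, RRange, OverflowingBinaryOperator::NoUnsignedWrap);
  return NoWrap.contains(LRange);
}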
- if (NWRegion.isEmptySet()) - return false; - Value *LHS = II->getOperand(0); - ConstantRange LRange = LVI->getConstantRange(LHS, II->getParent(), II); - return NWRegion.contains(LRange); - }; - switch (II->getIntrinsicID()) { - default: - break; - case Intrinsic::uadd_with_overflow: - return NoWrap(Instruction::Add, OBO::NoUnsignedWrap); - case Intrinsic::sadd_with_overflow: - return NoWrap(Instruction::Add, OBO::NoSignedWrap); - case Intrinsic::usub_with_overflow: - return NoWrap(Instruction::Sub, OBO::NoUnsignedWrap); - case Intrinsic::ssub_with_overflow: - return NoWrap(Instruction::Sub, OBO::NoSignedWrap); - } - return false; +// See if we can prove that the given binary op intrinsic will not overflow. +static bool willNotOverflow(BinaryOpIntrinsic *BO, LazyValueInfo *LVI) { + ConstantRange LRange = LVI->getConstantRange( + BO->getLHS(), BO->getParent(), BO); + ConstantRange RRange = LVI->getConstantRange( + BO->getRHS(), BO->getParent(), BO); + ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + BO->getBinaryOp(), RRange, BO->getNoWrapKind()); + return NWRegion.contains(LRange); } -static void processOverflowIntrinsic(IntrinsicInst *II) { - IRBuilder<> B(II); - Value *NewOp = nullptr; - switch (II->getIntrinsicID()) { - default: - llvm_unreachable("Unexpected instruction."); - case Intrinsic::uadd_with_overflow: - case Intrinsic::sadd_with_overflow: - NewOp = B.CreateAdd(II->getOperand(0), II->getOperand(1), II->getName()); - break; - case Intrinsic::usub_with_overflow: - case Intrinsic::ssub_with_overflow: - NewOp = B.CreateSub(II->getOperand(0), II->getOperand(1), II->getName()); - break; +static void processOverflowIntrinsic(WithOverflowInst *WO) { + IRBuilder<> B(WO); + Value *NewOp = B.CreateBinOp( + WO->getBinaryOp(), WO->getLHS(), WO->getRHS(), WO->getName()); + // Constant-folding could have happened. + if (auto *Inst = dyn_cast<Instruction>(NewOp)) { + if (WO->isSigned()) + Inst->setHasNoSignedWrap(); + else + Inst->setHasNoUnsignedWrap(); } + + Value *NewI = B.CreateInsertValue(UndefValue::get(WO->getType()), NewOp, 0); + NewI = B.CreateInsertValue(NewI, ConstantInt::getFalse(WO->getContext()), 1); + WO->replaceAllUsesWith(NewI); + WO->eraseFromParent(); ++NumOverflows; - Value *NewI = B.CreateInsertValue(UndefValue::get(II->getType()), NewOp, 0); - NewI = B.CreateInsertValue(NewI, ConstantInt::getFalse(II->getContext()), 1); - II->replaceAllUsesWith(NewI); - II->eraseFromParent(); +} + +static void processSaturatingInst(SaturatingInst *SI) { + BinaryOperator *BinOp = BinaryOperator::Create( + SI->getBinaryOp(), SI->getLHS(), SI->getRHS(), SI->getName(), SI); + BinOp->setDebugLoc(SI->getDebugLoc()); + if (SI->isSigned()) + BinOp->setHasNoSignedWrap(); + else + BinOp->setHasNoUnsignedWrap(); + + SI->replaceAllUsesWith(BinOp); + SI->eraseFromParent(); + ++NumSaturating; } /// Infer nonnull attributes for the arguments at the specified callsite. 
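
The new willNotOverflow and the processBinOp rewrite further down rest on the same containment check: ConstantRange::makeGuaranteedNoWrapRegion computes, from the known range of one operand, the set of values the other operand may take without wrapping, and the operation is overflow-free when the other operand's known range is contained in that region. Once a saturating intrinsic is proven overflow-free, processSaturatingInst replaces it with ordinary arithmetic carrying nuw/nsw. A self-contained analogue for unsigned 8-bit addition, with plain integer intervals standing in for ConstantRange and LVI (Range, noUnsignedWrapRegionForAdd and uaddSat are illustrative names only):

#include <cstdint>
#include <cstdio>

// A closed interval of uint8_t values, loosely analogous to a ConstantRange.
struct Range { uint8_t Lo, Hi; };

// All LHS values x for which x + r cannot wrap for *any* r in RHS:
// x must satisfy x <= 255 - RHS.Hi.
static Range noUnsignedWrapRegionForAdd(Range RHS) {
  return {0, static_cast<uint8_t>(255 - RHS.Hi)};
}

static bool contains(Range Outer, Range Inner) {
  return Outer.Lo <= Inner.Lo && Inner.Hi <= Outer.Hi;
}

// uadd.sat-style saturating add.
static uint8_t uaddSat(uint8_t A, uint8_t B) {
  unsigned Sum = unsigned(A) + unsigned(B);
  return Sum > 255 ? 255 : uint8_t(Sum);
}

int main() {
  Range LHS{0, 100}, RHS{0, 50};
  // If the LHS range sits inside the no-wrap region derived from RHS, the
  // saturation clamp can never trigger, so a saturating add behaves exactly
  // like a plain add (which the pass then tags "nuw" in IR).
  if (contains(noUnsignedWrapRegionForAdd(RHS), LHS))
    std::printf("uadd.sat(100, 50) == 100 + 50 == %u\n",
                unsigned(uaddSat(100, 50)));
  return 0;
}
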
@@ -456,13 +454,44 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { SmallVector<unsigned, 4> ArgNos; unsigned ArgNo = 0; - if (auto *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) { - if (willNotOverflow(II, LVI)) { - processOverflowIntrinsic(II); + if (auto *WO = dyn_cast<WithOverflowInst>(CS.getInstruction())) { + if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) { + processOverflowIntrinsic(WO); + return true; + } + } + + if (auto *SI = dyn_cast<SaturatingInst>(CS.getInstruction())) { + if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) { + processSaturatingInst(SI); return true; } } + // Deopt bundle operands are intended to capture state with minimal + // perturbance of the code otherwise. If we can find a constant value for + // any such operand and remove a use of the original value, that's + // desireable since it may allow further optimization of that value (e.g. via + // single use rules in instcombine). Since deopt uses tend to, + // idiomatically, appear along rare conditional paths, it's reasonable likely + // we may have a conditional fact with which LVI can fold. + if (auto DeoptBundle = CS.getOperandBundle(LLVMContext::OB_deopt)) { + bool Progress = false; + for (const Use &ConstU : DeoptBundle->Inputs) { + Use &U = const_cast<Use&>(ConstU); + Value *V = U.get(); + if (V->getType()->isVectorTy()) continue; + if (isa<Constant>(V)) continue; + + Constant *C = LVI->getConstant(V, CS.getParent(), CS.getInstruction()); + if (!C) continue; + U.set(C); + Progress = true; + } + if (Progress) + return true; + } + for (Value *V : CS.args()) { PointerType *Type = dyn_cast<PointerType>(V->getType()); // Try to mark pointer typed parameters as non-null. We skip the @@ -512,7 +541,7 @@ static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) { // Find the smallest power of two bitwidth that's sufficient to hold Instr's // operands. auto OrigWidth = Instr->getType()->getIntegerBitWidth(); - ConstantRange OperandRange(OrigWidth, /*isFullset=*/false); + ConstantRange OperandRange(OrigWidth, /*isFullSet=*/false); for (Value *Operand : Instr->operands()) { OperandRange = OperandRange.unionWith( LVI->getConstantRange(Operand, Instr->getParent())); @@ -603,55 +632,42 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { return true; } -static bool processAdd(BinaryOperator *AddOp, LazyValueInfo *LVI) { +static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) { using OBO = OverflowingBinaryOperator; - if (DontProcessAdds) + if (DontAddNoWrapFlags) return false; - if (AddOp->getType()->isVectorTy()) + if (BinOp->getType()->isVectorTy()) return false; - bool NSW = AddOp->hasNoSignedWrap(); - bool NUW = AddOp->hasNoUnsignedWrap(); + bool NSW = BinOp->hasNoSignedWrap(); + bool NUW = BinOp->hasNoUnsignedWrap(); if (NSW && NUW) return false; - BasicBlock *BB = AddOp->getParent(); + BasicBlock *BB = BinOp->getParent(); - Value *LHS = AddOp->getOperand(0); - Value *RHS = AddOp->getOperand(1); + Value *LHS = BinOp->getOperand(0); + Value *RHS = BinOp->getOperand(1); - ConstantRange LRange = LVI->getConstantRange(LHS, BB, AddOp); - - // Initialize RRange only if we need it. If we know that guaranteed no wrap - // range for the given LHS range is empty don't spend time calculating the - // range for the RHS. 
- Optional<ConstantRange> RRange; - auto LazyRRange = [&] () { - if (!RRange) - RRange = LVI->getConstantRange(RHS, BB, AddOp); - return RRange.getValue(); - }; + ConstantRange LRange = LVI->getConstantRange(LHS, BB, BinOp); + ConstantRange RRange = LVI->getConstantRange(RHS, BB, BinOp); bool Changed = false; if (!NUW) { ConstantRange NUWRange = ConstantRange::makeGuaranteedNoWrapRegion( - BinaryOperator::Add, LRange, OBO::NoUnsignedWrap); - if (!NUWRange.isEmptySet()) { - bool NewNUW = NUWRange.contains(LazyRRange()); - AddOp->setHasNoUnsignedWrap(NewNUW); - Changed |= NewNUW; - } + BinOp->getOpcode(), RRange, OBO::NoUnsignedWrap); + bool NewNUW = NUWRange.contains(LRange); + BinOp->setHasNoUnsignedWrap(NewNUW); + Changed |= NewNUW; } if (!NSW) { ConstantRange NSWRange = ConstantRange::makeGuaranteedNoWrapRegion( - BinaryOperator::Add, LRange, OBO::NoSignedWrap); - if (!NSWRange.isEmptySet()) { - bool NewNSW = NSWRange.contains(LazyRRange()); - AddOp->setHasNoSignedWrap(NewNSW); - Changed |= NewNSW; - } + BinOp->getOpcode(), RRange, OBO::NoSignedWrap); + bool NewNSW = NSWRange.contains(LRange); + BinOp->setHasNoSignedWrap(NewNSW); + Changed |= NewNSW; } return Changed; @@ -725,7 +741,8 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT, BBChanged |= processAShr(cast<BinaryOperator>(II), LVI); break; case Instruction::Add: - BBChanged |= processAdd(cast<BinaryOperator>(II), LVI); + case Instruction::Sub: + BBChanged |= processBinOp(cast<BinaryOperator>(II), LVI); break; } } diff --git a/lib/Transforms/Scalar/DCE.cpp b/lib/Transforms/Scalar/DCE.cpp index 4c964e6e888c..479e0ed74074 100644 --- a/lib/Transforms/Scalar/DCE.cpp +++ b/lib/Transforms/Scalar/DCE.cpp @@ -1,9 +1,8 @@ //===- DCE.cpp - Code to perform dead code elimination --------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp index 469930ca6a19..a81645745b48 100644 --- a/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -1,9 +1,8 @@ //===- DeadStoreElimination.cpp - Fast Dead Store Elimination -------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -29,8 +28,8 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/OrderedBasicBlock.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" @@ -57,6 +56,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> #include <cstddef> @@ -98,9 +98,8 @@ using InstOverlapIntervalsTy = DenseMap<Instruction *, OverlapIntervalsTy>; static void deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, MemoryDependenceResults &MD, const TargetLibraryInfo &TLI, - InstOverlapIntervalsTy &IOL, - DenseMap<Instruction*, size_t> *InstrOrdering, - SmallSetVector<Value *, 16> *ValueSet = nullptr) { + InstOverlapIntervalsTy &IOL, OrderedBasicBlock &OBB, + SmallSetVector<const Value *, 16> *ValueSet = nullptr) { SmallVector<Instruction*, 32> NowDeadInsts; NowDeadInsts.push_back(I); @@ -136,8 +135,8 @@ deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, } if (ValueSet) ValueSet->remove(DeadInst); - InstrOrdering->erase(DeadInst); IOL.erase(DeadInst); + OBB.eraseInstruction(DeadInst); if (NewIter == DeadInst->getIterator()) NewIter = DeadInst->eraseFromParent(); @@ -657,8 +656,7 @@ static void findUnconditionalPreds(SmallVectorImpl<BasicBlock *> &Blocks, static bool handleFree(CallInst *F, AliasAnalysis *AA, MemoryDependenceResults *MD, DominatorTree *DT, const TargetLibraryInfo *TLI, - InstOverlapIntervalsTy &IOL, - DenseMap<Instruction*, size_t> *InstrOrdering) { + InstOverlapIntervalsTy &IOL, OrderedBasicBlock &OBB) { bool MadeChange = false; MemoryLocation Loc = MemoryLocation(F->getOperand(0)); @@ -692,7 +690,7 @@ static bool handleFree(CallInst *F, AliasAnalysis *AA, // DCE instructions only used to calculate that store. BasicBlock::iterator BBI(Dependency); - deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL, InstrOrdering); + deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL, OBB); ++NumFastStores; MadeChange = true; @@ -715,7 +713,7 @@ static bool handleFree(CallInst *F, AliasAnalysis *AA, /// the DeadStackObjects set. If so, they become live because the location is /// being loaded. static void removeAccessedObjects(const MemoryLocation &LoadedLoc, - SmallSetVector<Value *, 16> &DeadStackObjects, + SmallSetVector<const Value *, 16> &DeadStackObjects, const DataLayout &DL, AliasAnalysis *AA, const TargetLibraryInfo *TLI, const Function *F) { @@ -728,12 +726,12 @@ static void removeAccessedObjects(const MemoryLocation &LoadedLoc, // If the kill pointer can be easily reduced to an alloca, don't bother doing // extraneous AA queries. if (isa<AllocaInst>(UnderlyingPointer) || isa<Argument>(UnderlyingPointer)) { - DeadStackObjects.remove(const_cast<Value*>(UnderlyingPointer)); + DeadStackObjects.remove(UnderlyingPointer); return; } // Remove objects that could alias LoadedLoc. - DeadStackObjects.remove_if([&](Value *I) { + DeadStackObjects.remove_if([&](const Value *I) { // See if the loaded location could alias the stack location. 
MemoryLocation StackLoc(I, getPointerSize(I, DL, *TLI, F)); return !AA->isNoAlias(StackLoc, LoadedLoc); @@ -747,15 +745,15 @@ static void removeAccessedObjects(const MemoryLocation &LoadedLoc, /// store i32 1, i32* %A /// ret void static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, - MemoryDependenceResults *MD, - const TargetLibraryInfo *TLI, - InstOverlapIntervalsTy &IOL, - DenseMap<Instruction*, size_t> *InstrOrdering) { + MemoryDependenceResults *MD, + const TargetLibraryInfo *TLI, + InstOverlapIntervalsTy &IOL, + OrderedBasicBlock &OBB) { bool MadeChange = false; // Keep track of all of the stack objects that are dead at the end of the // function. - SmallSetVector<Value*, 16> DeadStackObjects; + SmallSetVector<const Value*, 16> DeadStackObjects; // Find all of the alloca'd pointers in the entry block. BasicBlock &Entry = BB.getParent()->front(); @@ -784,12 +782,12 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, // If we find a store, check to see if it points into a dead stack value. if (hasAnalyzableMemoryWrite(&*BBI, *TLI) && isRemovable(&*BBI)) { // See through pointer-to-pointer bitcasts - SmallVector<Value *, 4> Pointers; + SmallVector<const Value *, 4> Pointers; GetUnderlyingObjects(getStoredPointerOperand(&*BBI), Pointers, DL); // Stores to stack values are valid candidates for removal. bool AllDead = true; - for (Value *Pointer : Pointers) + for (const Value *Pointer : Pointers) if (!DeadStackObjects.count(Pointer)) { AllDead = false; break; @@ -800,7 +798,8 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, LLVM_DEBUG(dbgs() << "DSE: Dead Store at End of Block:\n DEAD: " << *Dead << "\n Objects: "; - for (SmallVectorImpl<Value *>::iterator I = Pointers.begin(), + for (SmallVectorImpl<const Value *>::iterator I = + Pointers.begin(), E = Pointers.end(); I != E; ++I) { dbgs() << **I; @@ -810,7 +809,8 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, << '\n'); // DCE instructions only used to calculate that store. - deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, InstrOrdering, &DeadStackObjects); + deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, OBB, + &DeadStackObjects); ++NumFastStores; MadeChange = true; continue; @@ -821,7 +821,8 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, if (isInstructionTriviallyDead(&*BBI, TLI)) { LLVM_DEBUG(dbgs() << "DSE: Removing trivially dead instruction:\n DEAD: " << *&*BBI << '\n'); - deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, InstrOrdering, &DeadStackObjects); + deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, OBB, + &DeadStackObjects); ++NumFastOther; MadeChange = true; continue; @@ -847,7 +848,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, // If the call might load from any of our allocas, then any store above // the call is live. - DeadStackObjects.remove_if([&](Value *I) { + DeadStackObjects.remove_if([&](const Value *I) { // See if the call site touches the value. 
return isRefSet(AA->getModRefInfo( Call, I, getPointerSize(I, DL, *TLI, BB.getParent()))); @@ -946,7 +947,9 @@ static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierOffset, Value *Indices[1] = { ConstantInt::get(EarlierWriteLength->getType(), OffsetMoved)}; GetElementPtrInst *NewDestGEP = GetElementPtrInst::CreateInBounds( + EarlierIntrinsic->getRawDest()->getType()->getPointerElementType(), EarlierIntrinsic->getRawDest(), Indices, "", EarlierWrite); + NewDestGEP->setDebugLoc(EarlierIntrinsic->getDebugLoc()); EarlierIntrinsic->setDest(NewDestGEP); EarlierOffset = EarlierOffset + OffsetMoved; } @@ -1025,7 +1028,7 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, const DataLayout &DL, const TargetLibraryInfo *TLI, InstOverlapIntervalsTy &IOL, - DenseMap<Instruction*, size_t> *InstrOrdering) { + OrderedBasicBlock &OBB) { // Must be a store instruction. StoreInst *SI = dyn_cast<StoreInst>(Inst); if (!SI) @@ -1041,7 +1044,7 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, dbgs() << "DSE: Remove Store Of Load from same pointer:\n LOAD: " << *DepLoad << "\n STORE: " << *SI << '\n'); - deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, InstrOrdering); + deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, OBB); ++NumRedundantStores; return true; } @@ -1059,7 +1062,7 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, dbgs() << "DSE: Remove null store to the calloc'ed object:\n DEAD: " << *Inst << "\n OBJECT: " << *UnderlyingPointer << '\n'); - deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, InstrOrdering); + deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, OBB); ++NumRedundantStores; return true; } @@ -1073,11 +1076,8 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, const DataLayout &DL = BB.getModule()->getDataLayout(); bool MadeChange = false; - // FIXME: Maybe change this to use some abstraction like OrderedBasicBlock? - // The current OrderedBasicBlock can't deal with mutation at the moment. - size_t LastThrowingInstIndex = 0; - DenseMap<Instruction*, size_t> InstrOrdering; - size_t InstrIndex = 1; + OrderedBasicBlock OBB(&BB); + Instruction *LastThrowing = nullptr; // A map of interval maps representing partially-overwritten value parts. InstOverlapIntervalsTy IOL; @@ -1086,7 +1086,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ) { // Handle 'free' calls specially. if (CallInst *F = isFreeCall(&*BBI, TLI)) { - MadeChange |= handleFree(F, AA, MD, DT, TLI, IOL, &InstrOrdering); + MadeChange |= handleFree(F, AA, MD, DT, TLI, IOL, OBB); // Increment BBI after handleFree has potentially deleted instructions. // This ensures we maintain a valid iterator. ++BBI; @@ -1095,10 +1095,8 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, Instruction *Inst = &*BBI++; - size_t CurInstNumber = InstrIndex++; - InstrOrdering.insert(std::make_pair(Inst, CurInstNumber)); if (Inst->mayThrow()) { - LastThrowingInstIndex = CurInstNumber; + LastThrowing = Inst; continue; } @@ -1107,13 +1105,13 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, continue; // eliminateNoopStore will update in iterator, if necessary. - if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI, IOL, &InstrOrdering)) { + if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI, IOL, OBB)) { MadeChange = true; continue; } // If we find something that writes memory, get its memory dependence. 
- MemDepResult InstDep = MD->getDependency(Inst); + MemDepResult InstDep = MD->getDependency(Inst, &OBB); // Ignore any store where we can't find a local dependence. // FIXME: cross-block DSE would be fun. :) @@ -1158,9 +1156,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, // If the underlying object is a non-escaping memory allocation, any store // to it is dead along the unwind edge. Otherwise, we need to preserve // the store. - size_t DepIndex = InstrOrdering.lookup(DepWrite); - assert(DepIndex && "Unexpected instruction"); - if (DepIndex <= LastThrowingInstIndex) { + if (LastThrowing && OBB.dominates(DepWrite, LastThrowing)) { const Value* Underlying = GetUnderlyingObject(DepLoc.Ptr, DL); bool IsStoreDeadOnUnwind = isa<AllocaInst>(Underlying); if (!IsStoreDeadOnUnwind) { @@ -1191,12 +1187,12 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, << "\n KILLER: " << *Inst << '\n'); // Delete the store and now-dead instructions that feed it. - deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, &InstrOrdering); + deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, OBB); ++NumFastStores; MadeChange = true; // We erased DepWrite; start over. - InstDep = MD->getDependency(Inst); + InstDep = MD->getDependency(Inst, &OBB); continue; } else if ((OR == OW_End && isShortenableAtTheEnd(DepWrite)) || ((OR == OW_Begin && @@ -1215,12 +1211,17 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, auto *Earlier = dyn_cast<StoreInst>(DepWrite); auto *Later = dyn_cast<StoreInst>(Inst); if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) && + DL.typeSizeEqualsStoreSize( + Earlier->getValueOperand()->getType()) && Later && isa<ConstantInt>(Later->getValueOperand()) && + DL.typeSizeEqualsStoreSize( + Later->getValueOperand()->getType()) && memoryIsNotModifiedBetween(Earlier, Later, AA)) { // If the store we find is: // a) partially overwritten by the store to 'Loc' // b) the later store is fully contained in the earlier one and // c) they both have a constant value + // d) none of the two stores need padding // Merge the two stores, replacing the earlier store's value with a // merge of both values. // TODO: Deal with other constant types (vectors, etc), and probably @@ -1264,14 +1265,11 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, ++NumModifiedStores; // Remove earlier, wider, store - size_t Idx = InstrOrdering.lookup(DepWrite); - InstrOrdering.erase(DepWrite); - InstrOrdering.insert(std::make_pair(SI, Idx)); + OBB.replaceInstruction(DepWrite, SI); // Delete the old stores and now-dead instructions that feed them. - deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL, &InstrOrdering); - deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, - &InstrOrdering); + deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL, OBB); + deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, OBB); MadeChange = true; // We erased DepWrite and Inst (Loc); start over. @@ -1306,7 +1304,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, // If this block ends in a return, unwind, or unreachable, all allocas are // dead at its end, which means stores to them are also dead. 
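
The DeadStoreElimination hunks above replace the hand-maintained DenseMap<Instruction*, size_t> numbering (and the LastThrowingInstIndex bookkeeping) with OrderedBasicBlock, which answers "does A come before B in this block" queries and is told about deletions and in-place replacements as the pass mutates the block. A rough standalone sketch of that kind of position table, eagerly numbered here for brevity where the real class numbers instructions lazily (OrderedBlock and Inst are made-up stand-ins, not LLVM classes):

#include <unordered_map>
#include <vector>

struct Inst { const char *Name; }; // stand-in for llvm::Instruction

// Simplified stand-in for llvm::OrderedBasicBlock: record each instruction's
// position once, then answer comes-before queries while the block is mutated.
class OrderedBlock {
  std::unordered_map<const Inst *, unsigned> Number;

public:
  explicit OrderedBlock(const std::vector<const Inst *> &Insts) {
    unsigned Idx = 0;
    for (const Inst *I : Insts)
      Number[I] = Idx++;
  }

  // True if A appears strictly before B in the block.
  bool dominates(const Inst *A, const Inst *B) const {
    return Number.at(A) < Number.at(B);
  }

  // Drop an instruction that is being erased from the block.
  void eraseInstruction(const Inst *I) { Number.erase(I); }

  // A replacement instruction inherits the position of the one it replaces,
  // as DSE does when it merges two stores into a single new store.
  void replaceInstruction(const Inst *Old, const Inst *New) {
    auto It = Number.find(Old);
    if (It == Number.end())
      return;
    unsigned Idx = It->second;
    Number.erase(It);
    Number[New] = Idx;
  }
};

int main() {
  Inst Store{"store"}, MayThrow{"call @may_throw"}, Kill{"store (killer)"};
  std::vector<const Inst *> Block{&Store, &MayThrow, &Kill};
  OrderedBlock OBB(Block);
  // The DSE query above: an earlier store that precedes a throwing instruction
  // must be kept unless the underlying object is a non-escaping alloca.
  bool EarlierStorePrecedesThrow = OBB.dominates(&Store, &MayThrow);
  (void)EarlierStorePrecedesThrow;
  return 0;
}
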
if (BB.getTerminator()->getNumSuccessors() == 0) - MadeChange |= handleEndBlock(BB, AA, MD, TLI, IOL, &InstrOrdering); + MadeChange |= handleEndBlock(BB, AA, MD, TLI, IOL, OBB); return MadeChange; } diff --git a/lib/Transforms/Scalar/DivRemPairs.cpp b/lib/Transforms/Scalar/DivRemPairs.cpp index ffcf34f1cf7a..876681b4f9de 100644 --- a/lib/Transforms/Scalar/DivRemPairs.cpp +++ b/lib/Transforms/Scalar/DivRemPairs.cpp @@ -1,9 +1,8 @@ //===- DivRemPairs.cpp - Hoist/decompose division and remainder -*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp index 1f09979b3382..f1f075257020 100644 --- a/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/lib/Transforms/Scalar/EarlyCSE.cpp @@ -1,9 +1,8 @@ //===- EarlyCSE.cpp - Simple and fast CSE pass ----------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -76,6 +75,16 @@ STATISTIC(NumDSE, "Number of trivial dead stores removed"); DEBUG_COUNTER(CSECounter, "early-cse", "Controls which instructions are removed"); +static cl::opt<unsigned> EarlyCSEMssaOptCap( + "earlycse-mssa-optimization-cap", cl::init(500), cl::Hidden, + cl::desc("Enable imprecision in EarlyCSE in pathological cases, in exchange " + "for faster compile. Caps the MemorySSA clobbering calls.")); + +static cl::opt<bool> EarlyCSEDebugHash( + "earlycse-debug-hash", cl::init(false), cl::Hidden, + cl::desc("Perform extra assertion checking to verify that SimpleValue's hash " + "function is well-behaved w.r.t. its isEqual predicate")); + //===----------------------------------------------------------------------===// // SimpleValue //===----------------------------------------------------------------------===// @@ -126,7 +135,33 @@ template <> struct DenseMapInfo<SimpleValue> { } // end namespace llvm -unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) { +/// Match a 'select' including an optional 'not's of the condition. +static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A, + Value *&B, + SelectPatternFlavor &Flavor) { + // Return false if V is not even a select. + if (!match(V, m_Select(m_Value(Cond), m_Value(A), m_Value(B)))) + return false; + + // Look through a 'not' of the condition operand by swapping A/B. + Value *CondNot; + if (match(Cond, m_Not(m_Value(CondNot)))) { + Cond = CondNot; + std::swap(A, B); + } + + // Set flavor if we find a match, or set it to unknown otherwise; in + // either case, return true to indicate that this is a select we can + // process. 
+ if (auto *CmpI = dyn_cast<ICmpInst>(Cond)) + Flavor = matchDecomposedSelectPattern(CmpI, A, B, A, B).Flavor; + else + Flavor = SPF_UNKNOWN; + + return true; +} + +static unsigned getHashValueImpl(SimpleValue Val) { Instruction *Inst = Val.Inst; // Hash in all of the operands as pointers. if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst)) { @@ -139,32 +174,56 @@ unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) { } if (CmpInst *CI = dyn_cast<CmpInst>(Inst)) { + // Compares can be commuted by swapping the comparands and + // updating the predicate. Choose the form that has the + // comparands in sorted order, or in the case of a tie, the + // one with the lower predicate. Value *LHS = CI->getOperand(0); Value *RHS = CI->getOperand(1); CmpInst::Predicate Pred = CI->getPredicate(); - if (Inst->getOperand(0) > Inst->getOperand(1)) { + CmpInst::Predicate SwappedPred = CI->getSwappedPredicate(); + if (std::tie(LHS, Pred) > std::tie(RHS, SwappedPred)) { std::swap(LHS, RHS); - Pred = CI->getSwappedPredicate(); + Pred = SwappedPred; } return hash_combine(Inst->getOpcode(), Pred, LHS, RHS); } - // Hash min/max/abs (cmp + select) to allow for commuted operands. - // Min/max may also have non-canonical compare predicate (eg, the compare for - // smin may use 'sgt' rather than 'slt'), and non-canonical operands in the - // compare. - Value *A, *B; - SelectPatternFlavor SPF = matchSelectPattern(Inst, A, B).Flavor; - // TODO: We should also detect FP min/max. - if (SPF == SPF_SMIN || SPF == SPF_SMAX || - SPF == SPF_UMIN || SPF == SPF_UMAX) { - if (A > B) + // Hash general selects to allow matching commuted true/false operands. + SelectPatternFlavor SPF; + Value *Cond, *A, *B; + if (matchSelectWithOptionalNotCond(Inst, Cond, A, B, SPF)) { + // Hash min/max/abs (cmp + select) to allow for commuted operands. + // Min/max may also have non-canonical compare predicate (eg, the compare for + // smin may use 'sgt' rather than 'slt'), and non-canonical operands in the + // compare. + // TODO: We should also detect FP min/max. + if (SPF == SPF_SMIN || SPF == SPF_SMAX || + SPF == SPF_UMIN || SPF == SPF_UMAX) { + if (A > B) + std::swap(A, B); + return hash_combine(Inst->getOpcode(), SPF, A, B); + } + if (SPF == SPF_ABS || SPF == SPF_NABS) { + // ABS/NABS always puts the input in A and its negation in B. + return hash_combine(Inst->getOpcode(), SPF, A, B); + } + + // Hash general selects to allow matching commuted true/false operands. + + // If we do not have a compare as the condition, just hash in the condition. + CmpInst::Predicate Pred; + Value *X, *Y; + if (!match(Cond, m_Cmp(Pred, m_Value(X), m_Value(Y)))) + return hash_combine(Inst->getOpcode(), Cond, A, B); + + // Similar to cmp normalization (above) - canonicalize the predicate value: + // select (icmp Pred, X, Y), A, B --> select (icmp InvPred, X, Y), B, A + if (CmpInst::getInversePredicate(Pred) < Pred) { + Pred = CmpInst::getInversePredicate(Pred); std::swap(A, B); - return hash_combine(Inst->getOpcode(), SPF, A, B); - } - if (SPF == SPF_ABS || SPF == SPF_NABS) { - // ABS/NABS always puts the input in A and its negation in B. 
- return hash_combine(Inst->getOpcode(), SPF, A, B); + } + return hash_combine(Inst->getOpcode(), Pred, X, Y, A, B); } if (CastInst *CI = dyn_cast<CastInst>(Inst)) @@ -179,8 +238,7 @@ unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) { IVI->getOperand(1), hash_combine_range(IVI->idx_begin(), IVI->idx_end())); - assert((isa<CallInst>(Inst) || isa<BinaryOperator>(Inst) || - isa<GetElementPtrInst>(Inst) || isa<SelectInst>(Inst) || + assert((isa<CallInst>(Inst) || isa<GetElementPtrInst>(Inst) || isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) || isa<ShuffleVectorInst>(Inst)) && "Invalid/unknown instruction"); @@ -191,7 +249,19 @@ unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) { hash_combine_range(Inst->value_op_begin(), Inst->value_op_end())); } -bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) { +unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) { +#ifndef NDEBUG + // If -earlycse-debug-hash was specified, return a constant -- this + // will force all hashing to collide, so we'll exhaustively search + // the table for a match, and the assertion in isEqual will fire if + // there's a bug causing equal keys to hash differently. + if (EarlyCSEDebugHash) + return 0; +#endif + return getHashValueImpl(Val); +} + +static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) { Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst; if (LHS.isSentinel() || RHS.isSentinel()) @@ -227,26 +297,68 @@ bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) { // Min/max/abs can occur with commuted operands, non-canonical predicates, // and/or non-canonical operands. - Value *LHSA, *LHSB; - SelectPatternFlavor LSPF = matchSelectPattern(LHSI, LHSA, LHSB).Flavor; - // TODO: We should also detect FP min/max. - if (LSPF == SPF_SMIN || LSPF == SPF_SMAX || - LSPF == SPF_UMIN || LSPF == SPF_UMAX || - LSPF == SPF_ABS || LSPF == SPF_NABS) { - Value *RHSA, *RHSB; - SelectPatternFlavor RSPF = matchSelectPattern(RHSI, RHSA, RHSB).Flavor; + // Selects can be non-trivially equivalent via inverted conditions and swaps. + SelectPatternFlavor LSPF, RSPF; + Value *CondL, *CondR, *LHSA, *RHSA, *LHSB, *RHSB; + if (matchSelectWithOptionalNotCond(LHSI, CondL, LHSA, LHSB, LSPF) && + matchSelectWithOptionalNotCond(RHSI, CondR, RHSA, RHSB, RSPF)) { if (LSPF == RSPF) { - // Abs results are placed in a defined order by matchSelectPattern. - if (LSPF == SPF_ABS || LSPF == SPF_NABS) + // TODO: We should also detect FP min/max. + if (LSPF == SPF_SMIN || LSPF == SPF_SMAX || + LSPF == SPF_UMIN || LSPF == SPF_UMAX) + return ((LHSA == RHSA && LHSB == RHSB) || + (LHSA == RHSB && LHSB == RHSA)); + + if (LSPF == SPF_ABS || LSPF == SPF_NABS) { + // Abs results are placed in a defined order by matchSelectPattern. 
return LHSA == RHSA && LHSB == RHSB; - return ((LHSA == RHSA && LHSB == RHSB) || - (LHSA == RHSB && LHSB == RHSA)); + } + + // select Cond, A, B <--> select not(Cond), B, A + if (CondL == CondR && LHSA == RHSA && LHSB == RHSB) + return true; + } + + // If the true/false operands are swapped and the conditions are compares + // with inverted predicates, the selects are equal: + // select (icmp Pred, X, Y), A, B <--> select (icmp InvPred, X, Y), B, A + // + // This also handles patterns with a double-negation in the sense of not + + // inverse, because we looked through a 'not' in the matching function and + // swapped A/B: + // select (cmp Pred, X, Y), A, B <--> select (not (cmp InvPred, X, Y)), B, A + // + // This intentionally does NOT handle patterns with a double-negation in + // the sense of not + not, because doing so could result in values + // comparing + // as equal that hash differently in the min/max/abs cases like: + // select (cmp slt, X, Y), X, Y <--> select (not (not (cmp slt, X, Y))), X, Y + // ^ hashes as min ^ would not hash as min + // In the context of the EarlyCSE pass, however, such cases never reach + // this code, as we simplify the double-negation before hashing the second + // select (and so still succeed at CSEing them). + if (LHSA == RHSB && LHSB == RHSA) { + CmpInst::Predicate PredL, PredR; + Value *X, *Y; + if (match(CondL, m_Cmp(PredL, m_Value(X), m_Value(Y))) && + match(CondR, m_Cmp(PredR, m_Specific(X), m_Specific(Y))) && + CmpInst::getInversePredicate(PredL) == PredR) + return true; } } return false; } +bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) { + // These comparisons are nontrivial, so assert that equality implies + // hash equality (DenseMap demands this as an invariant). + bool Result = isEqualImpl(LHS, RHS); + assert(!Result || (LHS.isSentinel() && LHS.Inst == RHS.Inst) || + getHashValueImpl(LHS) == getHashValueImpl(RHS)); + return Result; +} + //===----------------------------------------------------------------------===// // CallValue //===----------------------------------------------------------------------===// @@ -419,6 +531,7 @@ public: bool run(); private: + unsigned ClobberCounter = 0; // Almost a POD, but needs to call the constructors for the scoped hash // tables so that a new scope gets pushed on. These are RAII so that the // scope gets popped when the NodeScope is destroyed. @@ -608,36 +721,11 @@ private: MSSA->verifyMemorySSA(); // Removing a store here can leave MemorySSA in an unoptimized state by // creating MemoryPhis that have identical arguments and by creating - // MemoryUses whose defining access is not an actual clobber. We handle the - // phi case eagerly here. The non-optimized MemoryUse case is lazily - // updated by MemorySSA getClobberingMemoryAccess. - if (MemoryAccess *MA = MSSA->getMemoryAccess(Inst)) { - // Optimize MemoryPhi nodes that may become redundant by having all the - // same input values once MA is removed. - SmallSetVector<MemoryPhi *, 4> PhisToCheck; - SmallVector<MemoryAccess *, 8> WorkQueue; - WorkQueue.push_back(MA); - // Process MemoryPhi nodes in FIFO order using a ever-growing vector since - // we shouldn't be processing that many phis and this will avoid an - // allocation in almost all cases. 
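
The EarlyCSE changes above make isEqual accept more pairs as identical (commuted compares, selects whose conditions differ by a 'not' or an inverted predicate), which is what the new -earlycse-debug-hash option is guarding: DenseMap requires that isEqual(A, B) implies getHashValue(A) == getHashValue(B), and forcing every hash to a constant sends all lookups through isEqual so the added assertion catches any key pair that is "equal" but hashes differently. A stripped-down illustration of both halves, canonicalizing a commuted integer compare before hashing and checking the invariant (CmpKey, canonical, DebugHash and getHash are invented names, not the pass's real types):

#include <cassert>
#include <cstddef>
#include <functional>
#include <tuple>

// A toy "icmp"-style key: a predicate plus two operand identifiers.
enum Pred { LT, GT }; // GT is LT with the operands swapped
static Pred swapped(Pred P) { return P == LT ? GT : LT; }

struct CmpKey { Pred P; int L, R; };

// Canonicalize so that commuted forms ("a < b" vs. "b > a") hash and compare
// the same way, mirroring the std::tie-based swap in getHashValueImpl.
static CmpKey canonical(CmpKey K) {
  Pred SwappedP = swapped(K.P);
  if (std::tie(K.L, K.P) > std::tie(K.R, SwappedP))
    return {SwappedP, K.R, K.L};
  return K;
}

static bool isEqual(CmpKey A, CmpKey B) {
  CmpKey CA = canonical(A), CB = canonical(B);
  return CA.P == CB.P && CA.L == CB.L && CA.R == CB.R;
}

static bool DebugHash = false; // analogue of -earlycse-debug-hash

static std::size_t getHash(CmpKey K) {
  if (DebugHash)
    return 0; // collide everything so isEqual (and its assertion) always runs
  CmpKey C = canonical(K);
  return std::hash<int>()(int(C.P)) ^ (std::hash<int>()(C.L) << 1) ^
         (std::hash<int>()(C.R) << 2);
}

int main() {
  CmpKey A{LT, 1, 2}, B{GT, 2, 1}; // "1 < 2" and "2 > 1": the same fact
  assert(isEqual(A, B));
  // The DenseMap invariant the patch asserts: equal keys must hash equally.
  assert(!isEqual(A, B) || getHash(A) == getHash(B));
  return 0;
}
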
- for (unsigned I = 0; I < WorkQueue.size(); ++I) { - MemoryAccess *WI = WorkQueue[I]; - - for (auto *U : WI->users()) - if (MemoryPhi *MP = dyn_cast<MemoryPhi>(U)) - PhisToCheck.insert(MP); - - MSSAUpdater->removeMemoryAccess(WI); - - for (MemoryPhi *MP : PhisToCheck) { - MemoryAccess *FirstIn = MP->getIncomingValue(0); - if (llvm::all_of(MP->incoming_values(), - [=](Use &In) { return In == FirstIn; })) - WorkQueue.push_back(MP); - } - PhisToCheck.clear(); - } - } + // MemoryUses whose defining access is not an actual clobber. The phi case + // is handled by MemorySSA when passing OptimizePhis = true to + // removeMemoryAccess. The non-optimized MemoryUse case is lazily updated + // by MemorySSA's getClobberingMemoryAccess. + MSSAUpdater->removeMemoryAccess(Inst, true); } }; @@ -688,8 +776,13 @@ bool EarlyCSE::isSameMemGeneration(unsigned EarlierGeneration, // LaterInst, if LaterDef dominates EarlierInst then it can't occur between // EarlierInst and LaterInst and neither can any other write that potentially // clobbers LaterInst. - MemoryAccess *LaterDef = - MSSA->getWalker()->getClobberingMemoryAccess(LaterInst); + MemoryAccess *LaterDef; + if (ClobberCounter < EarlyCSEMssaOptCap) { + LaterDef = MSSA->getWalker()->getClobberingMemoryAccess(LaterInst); + ClobberCounter++; + } else + LaterDef = LaterMA->getDefiningAccess(); + return MSSA->dominates(LaterDef, EarlierMA); } @@ -1117,7 +1210,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // At the moment, we don't remove ordered stores, but do remove // unordered atomic stores. There's no special requirement (for // unordered atomics) about removing atomic stores only in favor of - // other atomic stores since we we're going to execute the non-atomic + // other atomic stores since we were going to execute the non-atomic // one anyway and the atomic one might never have become visible. if (LastStore) { ParseMemoryInst LastStoreMemInst(LastStore, TTI); @@ -1184,8 +1277,7 @@ bool EarlyCSE::run() { CurrentGeneration, DT.getRootNode(), DT.getRootNode()->begin(), DT.getRootNode()->end())); - // Save the current generation. - unsigned LiveOutGeneration = CurrentGeneration; + assert(!CurrentGeneration && "Create a new EarlyCSE instance to rerun it."); // Process the stack. while (!nodesToProcess.empty()) { @@ -1217,9 +1309,6 @@ bool EarlyCSE::run() { } } // while (!nodes...) - // Reset the current generation. - CurrentGeneration = LiveOutGeneration; - return Changed; } diff --git a/lib/Transforms/Scalar/FlattenCFGPass.cpp b/lib/Transforms/Scalar/FlattenCFGPass.cpp index 117b19fb8a42..31670b1464e4 100644 --- a/lib/Transforms/Scalar/FlattenCFGPass.cpp +++ b/lib/Transforms/Scalar/FlattenCFGPass.cpp @@ -1,9 +1,8 @@ //===- FlattenCFGPass.cpp - CFG Flatten Pass ----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/Float2Int.cpp b/lib/Transforms/Scalar/Float2Int.cpp index f2828e80bc58..4f83e869b303 100644 --- a/lib/Transforms/Scalar/Float2Int.cpp +++ b/lib/Transforms/Scalar/Float2Int.cpp @@ -1,9 +1,8 @@ //===- Float2Int.cpp - Demote floating point ops to work on integers ------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -148,10 +147,10 @@ void Float2IntPass::seen(Instruction *I, ConstantRange R) { // Helper - get a range representing a poison value. ConstantRange Float2IntPass::badRange() { - return ConstantRange(MaxIntegerBW + 1, true); + return ConstantRange::getFull(MaxIntegerBW + 1); } ConstantRange Float2IntPass::unknownRange() { - return ConstantRange(MaxIntegerBW + 1, false); + return ConstantRange::getEmpty(MaxIntegerBW + 1); } ConstantRange Float2IntPass::validateRange(ConstantRange R) { if (R.getBitWidth() > MaxIntegerBW + 1) @@ -195,12 +194,13 @@ void Float2IntPass::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) { // Path terminated cleanly - use the type of the integer input to seed // the analysis. unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits(); - auto Input = ConstantRange(BW, true); + auto Input = ConstantRange::getFull(BW); auto CastOp = (Instruction::CastOps)I->getOpcode(); seen(I, validateRange(Input.castOp(CastOp, MaxIntegerBW+1))); continue; } + case Instruction::FNeg: case Instruction::FAdd: case Instruction::FSub: case Instruction::FMul: @@ -241,6 +241,15 @@ void Float2IntPass::walkForwards() { case Instruction::SIToFP: llvm_unreachable("Should have been handled in walkForwards!"); + case Instruction::FNeg: + Op = [](ArrayRef<ConstantRange> Ops) { + assert(Ops.size() == 1 && "FNeg is a unary operator!"); + unsigned Size = Ops[0].getBitWidth(); + auto Zero = ConstantRange(APInt::getNullValue(Size)); + return Zero.sub(Ops[0]); + }; + break; + case Instruction::FAdd: case Instruction::FSub: case Instruction::FMul: @@ -427,7 +436,7 @@ Value *Float2IntPass::convert(Instruction *I, Type *ToTy) { } else if (Instruction *VI = dyn_cast<Instruction>(V)) { NewOperands.push_back(convert(VI, ToTy)); } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) { - APSInt Val(ToTy->getPrimitiveSizeInBits(), /*IsUnsigned=*/false); + APSInt Val(ToTy->getPrimitiveSizeInBits(), /*isUnsigned=*/false); bool Exact; CF->getValueAPF().convertToInteger(Val, APFloat::rmNearestTiesToEven, @@ -467,6 +476,10 @@ Value *Float2IntPass::convert(Instruction *I, Type *ToTy) { NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy); break; + case Instruction::FNeg: + NewV = IRB.CreateNeg(NewOperands[0], I->getName()); + break; + case Instruction::FAdd: case Instruction::FSub: case Instruction::FMul: diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp index 9861948c8297..1a02e9d33f49 100644 --- a/lib/Transforms/Scalar/GVN.cpp +++ b/lib/Transforms/Scalar/GVN.cpp @@ -1,9 +1,8 @@ //===- GVN.cpp - Eliminate redundant values and loads ---------------------===// // -// The LLVM Compiler 
Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -30,6 +29,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -46,8 +46,8 @@ #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -330,36 +330,15 @@ GVN::Expression GVN::ValueTable::createExtractvalueExpr(ExtractValueInst *EI) { e.type = EI->getType(); e.opcode = 0; - IntrinsicInst *I = dyn_cast<IntrinsicInst>(EI->getAggregateOperand()); - if (I != nullptr && EI->getNumIndices() == 1 && *EI->idx_begin() == 0 ) { - // EI might be an extract from one of our recognised intrinsics. If it - // is we'll synthesize a semantically equivalent expression instead on - // an extract value expression. - switch (I->getIntrinsicID()) { - case Intrinsic::sadd_with_overflow: - case Intrinsic::uadd_with_overflow: - e.opcode = Instruction::Add; - break; - case Intrinsic::ssub_with_overflow: - case Intrinsic::usub_with_overflow: - e.opcode = Instruction::Sub; - break; - case Intrinsic::smul_with_overflow: - case Intrinsic::umul_with_overflow: - e.opcode = Instruction::Mul; - break; - default: - break; - } - - if (e.opcode != 0) { - // Intrinsic recognized. Grab its args to finish building the expression. - assert(I->getNumArgOperands() == 2 && - "Expect two args for recognised intrinsics."); - e.varargs.push_back(lookupOrAdd(I->getArgOperand(0))); - e.varargs.push_back(lookupOrAdd(I->getArgOperand(1))); - return e; - } + WithOverflowInst *WO = dyn_cast<WithOverflowInst>(EI->getAggregateOperand()); + if (WO != nullptr && EI->getNumIndices() == 1 && *EI->idx_begin() == 0) { + // EI is an extract from one of our with.overflow intrinsics. Synthesize + // a semantically equivalent expression instead of an extract value + // expression. + e.opcode = WO->getBinaryOp(); + e.varargs.push_back(lookupOrAdd(WO->getLHS())); + e.varargs.push_back(lookupOrAdd(WO->getRHS())); + return e; } // Not a recognised intrinsic. 
Fall back to producing an extract value @@ -513,6 +492,7 @@ uint32_t GVN::ValueTable::lookupOrAdd(Value *V) { switch (I->getOpcode()) { case Instruction::Call: return lookupOrAddCall(cast<CallInst>(I)); + case Instruction::FNeg: case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -544,6 +524,7 @@ uint32_t GVN::ValueTable::lookupOrAdd(Value *V) { case Instruction::FPExt: case Instruction::PtrToInt: case Instruction::IntToPtr: + case Instruction::AddrSpaceCast: case Instruction::BitCast: case Instruction::Select: case Instruction::ExtractElement: @@ -879,11 +860,12 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, const DataLayout &DL = LI->getModule()->getDataLayout(); + Instruction *DepInst = DepInfo.getInst(); if (DepInfo.isClobber()) { // If the dependence is to a store that writes to a superset of the bits // read by the load, we can extract the bits we need for the load from the // stored value. - if (StoreInst *DepSI = dyn_cast<StoreInst>(DepInfo.getInst())) { + if (StoreInst *DepSI = dyn_cast<StoreInst>(DepInst)) { // Can't forward from non-atomic to atomic without violating memory model. if (Address && LI->isAtomic() <= DepSI->isAtomic()) { int Offset = @@ -899,7 +881,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, // load i32* P // load i8* (P+1) // if we have this, replace the later with an extraction from the former. - if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInfo.getInst())) { + if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInst)) { // If this is a clobber and L is the first instruction in its block, then // we have the first instruction in the entry block. // Can't forward from non-atomic to atomic without violating memory model. @@ -916,7 +898,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, // If the clobbering value is a memset/memcpy/memmove, see if we can // forward a value on from it. - if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInfo.getInst())) { + if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInst)) { if (Address && !LI->isAtomic()) { int Offset = analyzeLoadFromClobberingMemInst(LI->getType(), Address, DepMI, DL); @@ -930,8 +912,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, LLVM_DEBUG( // fast print dep, using operator<< on instruction is too slow. dbgs() << "GVN: load "; LI->printAsOperand(dbgs()); - Instruction *I = DepInfo.getInst(); - dbgs() << " is clobbered by " << *I << '\n';); + dbgs() << " is clobbered by " << *DepInst << '\n';); if (ORE->allowExtraAnalysis(DEBUG_TYPE)) reportMayClobberedLoad(LI, DepInfo, DT, ORE); @@ -939,8 +920,6 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, } assert(DepInfo.isDef() && "follows from above"); - Instruction *DepInst = DepInfo.getInst(); - // Loading the allocation -> undef. if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI) || // Loading immediately after lifetime begin -> undef. @@ -959,9 +938,8 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, // Reject loads and stores that are to the same address but are of // different types if we have to. If the stored value is larger or equal to // the loaded value, we can reuse it. - if (S->getValueOperand()->getType() != LI->getType() && - !canCoerceMustAliasedValueToLoad(S->getValueOperand(), - LI->getType(), DL)) + if (!canCoerceMustAliasedValueToLoad(S->getValueOperand(), LI->getType(), + DL)) return false; // Can't forward from non-atomic to atomic without violating memory model. 
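
With the createExtractvalueExpr change earlier in this file, an extractvalue of field 0 of a llvm.*.with.overflow call is still keyed as the underlying add/sub/mul, now obtained through WithOverflowInst::getBinaryOp and its LHS/RHS, so it receives the same value number as a plain binary operator on the same operands and the two can be CSE'd. A toy value-numbering table in the same spirit (Expr, Opcode and ValueTable below are illustrative, not GVN's real classes):

#include <cstdint>
#include <cstdio>
#include <map>
#include <tuple>

// A toy expression key: an opcode plus the value numbers of its operands.
enum Opcode { Add = 1, Sub = 2, Mul = 3 };
using Expr = std::tuple<Opcode, uint32_t, uint32_t>;

struct ValueTable {
  std::map<Expr, uint32_t> Table;
  uint32_t Next = 1;

  uint32_t lookupOrAdd(Expr E) {
    auto It = Table.find(E);
    if (It != Table.end())
      return It->second; // same expression -> same value number
    return Table[E] = Next++;
  }
};

int main() {
  ValueTable VT;
  uint32_t A = 10, B = 11; // pretend value numbers for %a and %b

  // %sum = add i32 %a, %b
  uint32_t PlainAdd = VT.lookupOrAdd(Expr{Add, A, B});

  // %wo = call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a, i32 %b)
  // %v  = extractvalue { i32, i1 } %wo, 0
  // The extractvalue of field 0 is keyed as the binary op itself, so it
  // collapses onto the plain add's value number.
  uint32_t FromOverflow = VT.lookupOrAdd(Expr{Add, A, B});

  std::printf("%u == %u\n", PlainAdd, FromOverflow);
  return 0;
}
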
@@ -976,8 +954,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, // If the types mismatch and we can't handle it, reject reuse of the load. // If the stored value is larger or equal to the loaded value, we can reuse // it. - if (LD->getType() != LI->getType() && - !canCoerceMustAliasedValueToLoad(LD, LI->getType(), DL)) + if (!canCoerceMustAliasedValueToLoad(LD, LI->getType(), DL)) return false; // Can't forward from non-atomic to atomic without violating memory model. @@ -1132,6 +1109,14 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, return false; } + // FIXME: Can we support the fallthrough edge? + if (isa<CallBrInst>(Pred->getTerminator())) { + LLVM_DEBUG( + dbgs() << "COULD NOT PRE LOAD BECAUSE OF CALLBR CRITICAL EDGE '" + << Pred->getName() << "': " << *LI << '\n'); + return false; + } + if (LoadBB->isEHPad()) { LLVM_DEBUG( dbgs() << "COULD NOT PRE LOAD BECAUSE OF AN EH PAD CRITICAL EDGE '" @@ -1220,9 +1205,8 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // Instructions that have been inserted in predecessor(s) to materialize // the load address do not retain their original debug locations. Doing // so could lead to confusing (but correct) source attributions. - // FIXME: How do we retain source locations without causing poor debugging - // behavior? - I->setDebugLoc(DebugLoc()); + if (const DebugLoc &DL = I->getDebugLoc()) + I->setDebugLoc(DebugLoc::get(0, 0, DL.getScope(), DL.getInlinedAt())); // FIXME: We really _ought_ to insert these value numbers into their // parent's availability map. However, in doing so, we risk getting into @@ -1235,10 +1219,10 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, BasicBlock *UnavailablePred = PredLoad.first; Value *LoadPtr = PredLoad.second; - auto *NewLoad = new LoadInst(LoadPtr, LI->getName()+".pre", - LI->isVolatile(), LI->getAlignment(), - LI->getOrdering(), LI->getSyncScopeID(), - UnavailablePred->getTerminator()); + auto *NewLoad = + new LoadInst(LI->getType(), LoadPtr, LI->getName() + ".pre", + LI->isVolatile(), LI->getAlignment(), LI->getOrdering(), + LI->getSyncScopeID(), UnavailablePred->getTerminator()); NewLoad->setDebugLoc(LI->getDebugLoc()); // Transfer the old load's AA tags to the new load. @@ -2168,8 +2152,8 @@ bool GVN::performScalarPRE(Instruction *CurInst) { return false; // We don't currently value number ANY inline asm calls. - if (CallInst *CallI = dyn_cast<CallInst>(CurInst)) - if (CallI->isInlineAsm()) + if (auto *CallB = dyn_cast<CallBase>(CurInst)) + if (CallB->isInlineAsm()) return false; uint32_t ValNo = VN.lookup(CurInst); @@ -2252,6 +2236,11 @@ bool GVN::performScalarPRE(Instruction *CurInst) { if (isa<IndirectBrInst>(PREPred->getTerminator())) return false; + // Don't do PRE across callbr. + // FIXME: Can we do this across the fallthrough edge? + if (isa<CallBrInst>(PREPred->getTerminator())) + return false; + // We can't do PRE safely on a critical edge, so instead we schedule // the edge to be split and perform the PRE the next time we iterate // on the function. 
@@ -2479,8 +2468,7 @@ void GVN::addDeadBlock(BasicBlock *BB) { for (BasicBlock::iterator II = B->begin(); isa<PHINode>(II); ++II) { PHINode &Phi = cast<PHINode>(*II); - Phi.setIncomingValue(Phi.getBasicBlockIndex(P), - UndefValue::get(Phi.getType())); + Phi.setIncomingValueForBlock(P, UndefValue::get(Phi.getType())); if (MD) MD->invalidateCachedPointerInfo(&Phi); } diff --git a/lib/Transforms/Scalar/GVNHoist.cpp b/lib/Transforms/Scalar/GVNHoist.cpp index 76a42d7fe750..7614599653c4 100644 --- a/lib/Transforms/Scalar/GVNHoist.cpp +++ b/lib/Transforms/Scalar/GVNHoist.cpp @@ -1,9 +1,8 @@ //===- GVNHoist.cpp - Hoist scalar and load expressions -------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -703,7 +702,7 @@ private: // Vector of PHIs contains PHIs for different instructions. // Sort the args according to their VNs, such that identical // instructions are together. - std::stable_sort(CHIs.begin(), CHIs.end(), cmpVN); + llvm::stable_sort(CHIs, cmpVN); auto TI = BB->getTerminator(); auto B = CHIs.begin(); // [PreIt, PHIIt) form a range of CHIs which have identical VNs. diff --git a/lib/Transforms/Scalar/GVNSink.cpp b/lib/Transforms/Scalar/GVNSink.cpp index 1df5f5400c14..054025755c69 100644 --- a/lib/Transforms/Scalar/GVNSink.cpp +++ b/lib/Transforms/Scalar/GVNSink.cpp @@ -1,9 +1,8 @@ //===- GVNSink.cpp - sink expressions into successors ---------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -442,6 +441,7 @@ public: break; case Instruction::Call: case Instruction::Invoke: + case Instruction::FNeg: case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -714,6 +714,15 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( // FIXME: If any of these fail, we should partition up the candidates to // try and continue making progress. Instruction *I0 = NewInsts[0]; + + // If all instructions that are going to participate don't have the same + // number of operands, we can't do any useful PHI analysis for all operands. 
+ auto hasDifferentNumOperands = [&I0](Instruction *I) { + return I->getNumOperands() != I0->getNumOperands(); + }; + if (any_of(NewInsts, hasDifferentNumOperands)) + return None; + for (unsigned OpNum = 0, E = I0->getNumOperands(); OpNum != E; ++OpNum) { ModelledPHI PHI(NewInsts, OpNum, ActivePreds); if (PHI.areAllIncomingValuesSame()) @@ -791,10 +800,7 @@ unsigned GVNSink::sinkBB(BasicBlock *BBEnd) { --LRI; } - std::stable_sort( - Candidates.begin(), Candidates.end(), - [](const SinkingInstructionCandidate &A, - const SinkingInstructionCandidate &B) { return A > B; }); + llvm::stable_sort(Candidates, std::greater<SinkingInstructionCandidate>()); LLVM_DEBUG(dbgs() << " -- Sinking candidates:\n"; for (auto &C : Candidates) dbgs() << " " << C << "\n";); diff --git a/lib/Transforms/Scalar/GuardWidening.cpp b/lib/Transforms/Scalar/GuardWidening.cpp index efc204d4f74b..e14f44bb7069 100644 --- a/lib/Transforms/Scalar/GuardWidening.cpp +++ b/lib/Transforms/Scalar/GuardWidening.cpp @@ -1,9 +1,8 @@ //===- GuardWidening.cpp - ---- Guard widening ----------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -83,6 +82,11 @@ static cl::opt<unsigned> FrequentBranchThreshold( "it is considered frequently taken"), cl::init(1000)); +static cl::opt<bool> + WidenBranchGuards("guard-widening-widen-branch-guards", cl::Hidden, + cl::desc("Whether or not we should widen guards " + "expressed as branches by widenable conditions"), + cl::init(true)); namespace { @@ -93,6 +97,10 @@ static Value *getCondition(Instruction *I) { "Bad guard intrinsic?"); return GI->getArgOperand(0); } + if (isGuardAsWidenableBranch(I)) { + auto *Cond = cast<BranchInst>(I)->getCondition(); + return cast<BinaryOperator>(Cond)->getOperand(0); + } return cast<BranchInst>(I)->getCondition(); } @@ -133,12 +141,12 @@ class GuardWideningImpl { /// guards. DenseSet<Instruction *> WidenedGuards; - /// Try to eliminate guard \p Guard by widening it into an earlier dominating - /// guard. \p DFSI is the DFS iterator on the dominator tree that is - /// currently visiting the block containing \p Guard, and \p GuardsPerBlock + /// Try to eliminate instruction \p Instr by widening it into an earlier + /// dominating guard. \p DFSI is the DFS iterator on the dominator tree that + /// is currently visiting the block containing \p Guard, and \p GuardsPerBlock /// maps BasicBlocks to the set of guards seen in that block. - bool eliminateGuardViaWidening( - Instruction *Guard, const df_iterator<DomTreeNode *> &DFSI, + bool eliminateInstrViaWidening( + Instruction *Instr, const df_iterator<DomTreeNode *> &DFSI, const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> & GuardsPerBlock, bool InvertCondition = false); @@ -162,28 +170,25 @@ class GuardWideningImpl { static StringRef scoreTypeToString(WideningScore WS); - /// Compute the score for widening the condition in \p DominatedGuard - /// (contained in \p DominatedGuardLoop) into \p DominatingGuard (contained in - /// \p DominatingGuardLoop). 
If \p InvertCond is set, then we widen the + /// Compute the score for widening the condition in \p DominatedInstr + /// into \p DominatingGuard. If \p InvertCond is set, then we widen the /// inverted condition of the dominating guard. - WideningScore computeWideningScore(Instruction *DominatedGuard, - Loop *DominatedGuardLoop, + WideningScore computeWideningScore(Instruction *DominatedInstr, Instruction *DominatingGuard, - Loop *DominatingGuardLoop, bool InvertCond); /// Helper to check if \p V can be hoisted to \p InsertPos. - bool isAvailableAt(Value *V, Instruction *InsertPos) { - SmallPtrSet<Instruction *, 8> Visited; + bool isAvailableAt(const Value *V, const Instruction *InsertPos) const { + SmallPtrSet<const Instruction *, 8> Visited; return isAvailableAt(V, InsertPos, Visited); } - bool isAvailableAt(Value *V, Instruction *InsertPos, - SmallPtrSetImpl<Instruction *> &Visited); + bool isAvailableAt(const Value *V, const Instruction *InsertPos, + SmallPtrSetImpl<const Instruction *> &Visited) const; /// Helper to hoist \p V to \p InsertPos. Guaranteed to succeed if \c /// isAvailableAt returned true. - void makeAvailableAt(Value *V, Instruction *InsertPos); + void makeAvailableAt(Value *V, Instruction *InsertPos) const; /// Common helper used by \c widenGuard and \c isWideningCondProfitable. Try /// to generate an expression computing the logical AND of \p Cond0 and (\p @@ -200,23 +205,23 @@ class GuardWideningImpl { /// pre-existing instruction in the IR that computes the result of this range /// check. class RangeCheck { - Value *Base; - ConstantInt *Offset; - Value *Length; + const Value *Base; + const ConstantInt *Offset; + const Value *Length; ICmpInst *CheckInst; public: - explicit RangeCheck(Value *Base, ConstantInt *Offset, Value *Length, - ICmpInst *CheckInst) + explicit RangeCheck(const Value *Base, const ConstantInt *Offset, + const Value *Length, ICmpInst *CheckInst) : Base(Base), Offset(Offset), Length(Length), CheckInst(CheckInst) {} - void setBase(Value *NewBase) { Base = NewBase; } - void setOffset(ConstantInt *NewOffset) { Offset = NewOffset; } + void setBase(const Value *NewBase) { Base = NewBase; } + void setOffset(const ConstantInt *NewOffset) { Offset = NewOffset; } - Value *getBase() const { return Base; } - ConstantInt *getOffset() const { return Offset; } + const Value *getBase() const { return Base; } + const ConstantInt *getOffset() const { return Offset; } const APInt &getOffsetValue() const { return getOffset()->getValue(); } - Value *getLength() const { return Length; }; + const Value *getLength() const { return Length; }; ICmpInst *getCheckInst() const { return CheckInst; } void print(raw_ostream &OS, bool PrintTypes = false) { @@ -238,19 +243,19 @@ class GuardWideningImpl { /// append them to \p Checks. Returns true on success, may clobber \c Checks /// on failure. bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks) { - SmallPtrSet<Value *, 8> Visited; + SmallPtrSet<const Value *, 8> Visited; return parseRangeChecks(CheckCond, Checks, Visited); } bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks, - SmallPtrSetImpl<Value *> &Visited); + SmallPtrSetImpl<const Value *> &Visited); /// Combine the checks in \p Checks into a smaller set of checks and append /// them into \p CombinedChecks. Return true on success (i.e. all of checks /// in \p Checks were combined into \p CombinedChecks). Clobbers \p Checks /// and \p CombinedChecks on success and on failure. 
bool combineRangeChecks(SmallVectorImpl<RangeCheck> &Checks, - SmallVectorImpl<RangeCheck> &CombinedChecks); + SmallVectorImpl<RangeCheck> &CombinedChecks) const; /// Can we compute the logical AND of \p Cond0 and \p Cond1 for the price of /// computing only one of the two expressions? @@ -266,8 +271,16 @@ class GuardWideningImpl { void widenGuard(Instruction *ToWiden, Value *NewCondition, bool InvertCondition) { Value *Result; - widenCondCommon(ToWiden->getOperand(0), NewCondition, ToWiden, Result, + widenCondCommon(getCondition(ToWiden), NewCondition, ToWiden, Result, InvertCondition); + Value *WidenableCondition = nullptr; + if (isGuardAsWidenableBranch(ToWiden)) { + auto *Cond = cast<BranchInst>(ToWiden)->getCondition(); + WidenableCondition = cast<BinaryOperator>(Cond)->getOperand(1); + } + if (WidenableCondition) + Result = BinaryOperator::CreateAnd(Result, WidenableCondition, + "guard.chk", ToWiden); setCondition(ToWiden, Result); } @@ -285,6 +298,14 @@ public: }; } +static bool isSupportedGuardInstruction(const Instruction *Insn) { + if (isGuard(Insn)) + return true; + if (WidenBranchGuards && isGuardAsWidenableBranch(Insn)) + return true; + return false; +} + bool GuardWideningImpl::run() { DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> GuardsInBlock; bool Changed = false; @@ -304,20 +325,20 @@ bool GuardWideningImpl::run() { auto &CurrentList = GuardsInBlock[BB]; for (auto &I : *BB) - if (isGuard(&I)) + if (isSupportedGuardInstruction(&I)) CurrentList.push_back(cast<Instruction>(&I)); for (auto *II : CurrentList) - Changed |= eliminateGuardViaWidening(II, DFI, GuardsInBlock); + Changed |= eliminateInstrViaWidening(II, DFI, GuardsInBlock); if (WidenFrequentBranches && BPI) if (auto *BI = dyn_cast<BranchInst>(BB->getTerminator())) if (BI->isConditional()) { // If one of branches of a conditional is likely taken, try to // eliminate it. if (BPI->getEdgeProbability(BB, 0U) >= *LikelyTaken) - Changed |= eliminateGuardViaWidening(BI, DFI, GuardsInBlock); + Changed |= eliminateInstrViaWidening(BI, DFI, GuardsInBlock); else if (BPI->getEdgeProbability(BB, 1U) >= *LikelyTaken) - Changed |= eliminateGuardViaWidening(BI, DFI, GuardsInBlock, + Changed |= eliminateInstrViaWidening(BI, DFI, GuardsInBlock, /*InvertCondition*/true); } } @@ -326,7 +347,7 @@ bool GuardWideningImpl::run() { for (auto *I : EliminatedGuardsAndBranches) if (!WidenedGuards.count(I)) { assert(isa<ConstantInt>(getCondition(I)) && "Should be!"); - if (isGuard(I)) + if (isSupportedGuardInstruction(I)) eliminateGuard(I); else { assert(isa<BranchInst>(I) && @@ -338,19 +359,18 @@ bool GuardWideningImpl::run() { return Changed; } -bool GuardWideningImpl::eliminateGuardViaWidening( - Instruction *GuardInst, const df_iterator<DomTreeNode *> &DFSI, +bool GuardWideningImpl::eliminateInstrViaWidening( + Instruction *Instr, const df_iterator<DomTreeNode *> &DFSI, const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> & GuardsInBlock, bool InvertCondition) { // Ignore trivial true or false conditions. These instructions will be // trivially eliminated by any cleanup pass. Do not erase them because other // guards can possibly be widened into them. - if (isa<ConstantInt>(getCondition(GuardInst))) + if (isa<ConstantInt>(getCondition(Instr))) return false; Instruction *BestSoFar = nullptr; auto BestScoreSoFar = WS_IllegalOrNegative; - auto *GuardInstLoop = LI.getLoopFor(GuardInst->getParent()); // In the set of dominating guards, find the one we can merge GuardInst with // for the most profit. 
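A minimal conceptual sketch of the rewrite that eliminateInstrViaWidening and widenGuard perform, assuming a guard(c) pseudo-call that stands in for the llvm.experimental.guard intrinsic (illustrative names only):
// Before widening, the second guard is dominated by the first:
//   guard(i < n);        // dominating guard
//   ...
//   guard(i + 1 < n);    // dominated guard
//
// After widening, the dominated condition is folded into the dominating
// guard and the dominated guard's condition becomes a constant, leaving it
// for later cleanup to remove:
//   guard(i < n && i + 1 < n);
//   ...
//   guard(true);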
@@ -358,12 +378,13 @@ bool GuardWideningImpl::eliminateGuardViaWidening( auto *CurBB = DFSI.getPath(i)->getBlock(); if (!BlockFilter(CurBB)) break; - auto *CurLoop = LI.getLoopFor(CurBB); assert(GuardsInBlock.count(CurBB) && "Must have been populated by now!"); const auto &GuardsInCurBB = GuardsInBlock.find(CurBB)->second; auto I = GuardsInCurBB.begin(); - auto E = GuardsInCurBB.end(); + auto E = Instr->getParent() == CurBB + ? std::find(GuardsInCurBB.begin(), GuardsInCurBB.end(), Instr) + : GuardsInCurBB.end(); #ifndef NDEBUG { @@ -379,21 +400,11 @@ bool GuardWideningImpl::eliminateGuardViaWidening( } #endif - assert((i == (e - 1)) == (GuardInst->getParent() == CurBB) && "Bad DFS?"); - - if (i == (e - 1) && CurBB->getTerminator() != GuardInst) { - // Corner case: make sure we're only looking at guards strictly dominating - // GuardInst when visiting GuardInst->getParent(). - auto NewEnd = std::find(I, E, GuardInst); - assert(NewEnd != E && "GuardInst not in its own block?"); - E = NewEnd; - } + assert((i == (e - 1)) == (Instr->getParent() == CurBB) && "Bad DFS?"); for (auto *Candidate : make_range(I, E)) { - auto Score = - computeWideningScore(GuardInst, GuardInstLoop, Candidate, CurLoop, - InvertCondition); - LLVM_DEBUG(dbgs() << "Score between " << *getCondition(GuardInst) + auto Score = computeWideningScore(Instr, Candidate, InvertCondition); + LLVM_DEBUG(dbgs() << "Score between " << *getCondition(Instr) << " and " << *getCondition(Candidate) << " is " << scoreTypeToString(Score) << "\n"); if (Score > BestScoreSoFar) { @@ -404,42 +415,45 @@ bool GuardWideningImpl::eliminateGuardViaWidening( } if (BestScoreSoFar == WS_IllegalOrNegative) { - LLVM_DEBUG(dbgs() << "Did not eliminate guard " << *GuardInst << "\n"); + LLVM_DEBUG(dbgs() << "Did not eliminate guard " << *Instr << "\n"); return false; } - assert(BestSoFar != GuardInst && "Should have never visited same guard!"); - assert(DT.dominates(BestSoFar, GuardInst) && "Should be!"); + assert(BestSoFar != Instr && "Should have never visited same guard!"); + assert(DT.dominates(BestSoFar, Instr) && "Should be!"); - LLVM_DEBUG(dbgs() << "Widening " << *GuardInst << " into " << *BestSoFar + LLVM_DEBUG(dbgs() << "Widening " << *Instr << " into " << *BestSoFar << " with score " << scoreTypeToString(BestScoreSoFar) << "\n"); - widenGuard(BestSoFar, getCondition(GuardInst), InvertCondition); + widenGuard(BestSoFar, getCondition(Instr), InvertCondition); auto NewGuardCondition = InvertCondition - ? ConstantInt::getFalse(GuardInst->getContext()) - : ConstantInt::getTrue(GuardInst->getContext()); - setCondition(GuardInst, NewGuardCondition); - EliminatedGuardsAndBranches.push_back(GuardInst); + ? 
ConstantInt::getFalse(Instr->getContext()) + : ConstantInt::getTrue(Instr->getContext()); + setCondition(Instr, NewGuardCondition); + EliminatedGuardsAndBranches.push_back(Instr); WidenedGuards.insert(BestSoFar); return true; } -GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( - Instruction *DominatedGuard, Loop *DominatedGuardLoop, - Instruction *DominatingGuard, Loop *DominatingGuardLoop, bool InvertCond) { +GuardWideningImpl::WideningScore +GuardWideningImpl::computeWideningScore(Instruction *DominatedInstr, + Instruction *DominatingGuard, + bool InvertCond) { + Loop *DominatedInstrLoop = LI.getLoopFor(DominatedInstr->getParent()); + Loop *DominatingGuardLoop = LI.getLoopFor(DominatingGuard->getParent()); bool HoistingOutOfLoop = false; - if (DominatingGuardLoop != DominatedGuardLoop) { + if (DominatingGuardLoop != DominatedInstrLoop) { // Be conservative and don't widen into a sibling loop. TODO: If the // sibling is colder, we should consider allowing this. if (DominatingGuardLoop && - !DominatingGuardLoop->contains(DominatedGuardLoop)) + !DominatingGuardLoop->contains(DominatedInstrLoop)) return WS_IllegalOrNegative; HoistingOutOfLoop = true; } - if (!isAvailableAt(getCondition(DominatedGuard), DominatingGuard)) + if (!isAvailableAt(getCondition(DominatedInstr), DominatingGuard)) return WS_IllegalOrNegative; // If the guard was conditional executed, it may never be reached @@ -450,7 +464,7 @@ GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( // here. TODO: evaluate cost model for spurious deopt // NOTE: As written, this also lets us hoist right over another guard which // is essentially just another spelling for control flow. - if (isWideningCondProfitable(getCondition(DominatedGuard), + if (isWideningCondProfitable(getCondition(DominatedInstr), getCondition(DominatingGuard), InvertCond)) return HoistingOutOfLoop ? WS_VeryPositive : WS_Positive; @@ -462,7 +476,9 @@ GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( // throw, etc...). That choice appears arbitrary. auto MaybeHoistingOutOfIf = [&]() { auto *DominatingBlock = DominatingGuard->getParent(); - auto *DominatedBlock = DominatedGuard->getParent(); + auto *DominatedBlock = DominatedInstr->getParent(); + if (isGuardAsWidenableBranch(DominatingGuard)) + DominatingBlock = cast<BranchInst>(DominatingGuard)->getSuccessor(0); // Same Block? if (DominatedBlock == DominatingBlock) @@ -478,8 +494,9 @@ GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( return MaybeHoistingOutOfIf() ? 
WS_IllegalOrNegative : WS_Neutral; } -bool GuardWideningImpl::isAvailableAt(Value *V, Instruction *Loc, - SmallPtrSetImpl<Instruction *> &Visited) { +bool GuardWideningImpl::isAvailableAt( + const Value *V, const Instruction *Loc, + SmallPtrSetImpl<const Instruction *> &Visited) const { auto *Inst = dyn_cast<Instruction>(V); if (!Inst || DT.dominates(Inst, Loc) || Visited.count(Inst)) return true; @@ -499,7 +516,7 @@ bool GuardWideningImpl::isAvailableAt(Value *V, Instruction *Loc, [&](Value *Op) { return isAvailableAt(Op, Loc, Visited); }); } -void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) { +void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const { auto *Inst = dyn_cast<Instruction>(V); if (!Inst || DT.dominates(Inst, Loc)) return; @@ -597,7 +614,7 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, bool GuardWideningImpl::parseRangeChecks( Value *CheckCond, SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks, - SmallPtrSetImpl<Value *> &Visited) { + SmallPtrSetImpl<const Value *> &Visited) { if (!Visited.insert(CheckCond).second) return true; @@ -616,7 +633,7 @@ bool GuardWideningImpl::parseRangeChecks( IC->getPredicate() != ICmpInst::ICMP_UGT)) return false; - Value *CmpLHS = IC->getOperand(0), *CmpRHS = IC->getOperand(1); + const Value *CmpLHS = IC->getOperand(0), *CmpRHS = IC->getOperand(1); if (IC->getPredicate() == ICmpInst::ICMP_UGT) std::swap(CmpLHS, CmpRHS); @@ -669,13 +686,13 @@ bool GuardWideningImpl::parseRangeChecks( bool GuardWideningImpl::combineRangeChecks( SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks, - SmallVectorImpl<GuardWideningImpl::RangeCheck> &RangeChecksOut) { + SmallVectorImpl<GuardWideningImpl::RangeCheck> &RangeChecksOut) const { unsigned OldCount = Checks.size(); while (!Checks.empty()) { // Pick all of the range checks with a specific base and length, and try to // merge them. - Value *CurrentBase = Checks.front().getBase(); - Value *CurrentLength = Checks.front().getLength(); + const Value *CurrentBase = Checks.front().getBase(); + const Value *CurrentLength = Checks.front().getLength(); SmallVector<GuardWideningImpl::RangeCheck, 3> CurrentChecks; @@ -704,8 +721,8 @@ bool GuardWideningImpl::combineRangeChecks( // Note: std::sort should not invalidate the ChecksStart iterator. 
- ConstantInt *MinOffset = CurrentChecks.front().getOffset(), - *MaxOffset = CurrentChecks.back().getOffset(); + const ConstantInt *MinOffset = CurrentChecks.front().getOffset(); + const ConstantInt *MaxOffset = CurrentChecks.back().getOffset(); unsigned BitWidth = MaxOffset->getValue().getBitWidth(); if ((MaxOffset->getValue() - MinOffset->getValue()) @@ -800,6 +817,31 @@ PreservedAnalyses GuardWideningPass::run(Function &F, return PA; } +PreservedAnalyses GuardWideningPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + Function &F = *L.getHeader()->getParent(); + BranchProbabilityInfo *BPI = nullptr; + if (WidenFrequentBranches) + BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(F); + + BasicBlock *RootBB = L.getLoopPredecessor(); + if (!RootBB) + RootBB = L.getHeader(); + auto BlockFilter = [&](BasicBlock *BB) { + return BB == RootBB || L.contains(BB); + }; + if (!GuardWideningImpl(AR.DT, nullptr, AR.LI, BPI, + AR.DT.getNode(RootBB), + BlockFilter).run()) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + namespace { struct GuardWideningLegacyPass : public FunctionPass { static char ID; diff --git a/lib/Transforms/Scalar/IVUsersPrinter.cpp b/lib/Transforms/Scalar/IVUsersPrinter.cpp index 807593379283..e2022aba97c4 100644 --- a/lib/Transforms/Scalar/IVUsersPrinter.cpp +++ b/lib/Transforms/Scalar/IVUsersPrinter.cpp @@ -1,9 +1,8 @@ //===- IVUsersPrinter.cpp - Induction Variable Users Printer ----*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp index 48d8e457ba7c..f9fc698a4a9b 100644 --- a/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -1,9 +1,8 @@ //===- IndVarSimplify.cpp - Induction Variable Elimination ----------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -32,6 +31,7 @@ #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -43,6 +43,7 @@ #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -101,7 +102,7 @@ static cl::opt<bool> VerifyIndvars( "verify-indvars", cl::Hidden, cl::desc("Verify the ScalarEvolution result after running indvars")); -enum ReplaceExitVal { NeverRepl, OnlyCheapRepl, AlwaysRepl }; +enum ReplaceExitVal { NeverRepl, OnlyCheapRepl, NoHardUse, AlwaysRepl }; static cl::opt<ReplaceExitVal> ReplaceExitValue( "replexitval", cl::Hidden, cl::init(OnlyCheapRepl), @@ -109,6 +110,8 @@ static cl::opt<ReplaceExitVal> ReplaceExitValue( cl::values(clEnumValN(NeverRepl, "never", "never replace exit value"), clEnumValN(OnlyCheapRepl, "cheap", "only replace exit value when the cost is cheap"), + clEnumValN(NoHardUse, "noharduse", + "only replace exit values when loop def likely dead"), clEnumValN(AlwaysRepl, "always", "always replace exit value whenever possible"))); @@ -141,13 +144,15 @@ class IndVarSimplify { bool rewriteNonIntegerIVs(Loop *L); bool simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI); + bool optimizeLoopExits(Loop *L); bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet); bool rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter); bool rewriteFirstIterationLoopExitValues(Loop *L); bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) const; - bool linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, + bool linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB, + const SCEV *ExitCount, PHINode *IndVar, SCEVExpander &Rewriter); bool sinkUnusedInvariants(Loop *L); @@ -218,7 +223,9 @@ bool IndVarSimplify::isValidRewrite(Value *FromVal, Value *ToVal) { /// Determine the insertion point for this user. By default, insert immediately /// before the user. SCEVExpander or LICM will hoist loop invariants out of the /// loop. For PHI nodes, there may be multiple uses, so compute the nearest -/// common dominator for the incoming blocks. +/// common dominator for the incoming blocks. A nullptr can be returned if no +/// viable location is found: it may happen if User is a PHI and Def only comes +/// to this PHI from unreachable blocks. static Instruction *getInsertPointForUses(Instruction *User, Value *Def, DominatorTree *DT, LoopInfo *LI) { PHINode *PHI = dyn_cast<PHINode>(User); @@ -231,6 +238,10 @@ static Instruction *getInsertPointForUses(Instruction *User, Value *Def, continue; BasicBlock *InsertBB = PHI->getIncomingBlock(i); + + if (!DT->isReachableFromEntry(InsertBB)) + continue; + if (!InsertPt) { InsertPt = InsertBB->getTerminator(); continue; @@ -238,7 +249,11 @@ static Instruction *getInsertPointForUses(Instruction *User, Value *Def, InsertBB = DT->findNearestCommonDominator(InsertPt->getParent(), InsertBB); InsertPt = InsertBB->getTerminator(); } - assert(InsertPt && "Missing phi operand"); + + // If we have skipped all inputs, it means that Def only comes to Phi from + // unreachable blocks. 
+ if (!InsertPt) + return nullptr; auto *DefI = dyn_cast<Instruction>(Def); if (!DefI) @@ -621,8 +636,12 @@ bool IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { // Computing the value outside of the loop brings no benefit if it is // definitely used inside the loop in a way which can not be optimized - // away. - if (!isa<SCEVConstant>(ExitValue) && hasHardUserWithinLoop(L, Inst)) + // away. Avoid doing so unless we know we have a value which computes + // the ExitValue already. TODO: This should be merged into SCEV + // expander to leverage its knowledge of existing expressions. + if (ReplaceExitValue != AlwaysRepl && + !isa<SCEVConstant>(ExitValue) && !isa<SCEVUnknown>(ExitValue) && + hasHardUserWithinLoop(L, Inst)) continue; bool HighCost = Rewriter.isHighCostExpansion(ExitValue, L, Inst); @@ -707,8 +726,6 @@ bool IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { SmallVector<BasicBlock *, 8> ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); - auto *LoopHeader = L->getHeader(); - assert(LoopHeader && "Invalid loop"); bool MadeAnyChanges = false; for (auto *ExitBB : ExitBlocks) { @@ -719,11 +736,13 @@ bool IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { IncomingValIdx != E; ++IncomingValIdx) { auto *IncomingBB = PN.getIncomingBlock(IncomingValIdx); - // We currently only support loop exits from loop header. If the - // incoming block is not loop header, we need to recursively check - // all conditions starting from loop header are loop invariants. - // Additional support might be added in the future. - if (IncomingBB != LoopHeader) + // Can we prove that the exit must run on the first iteration if it + // runs at all? (i.e. early exits are fine for our purposes, but + // traces which lead to this exit being taken on the 2nd iteration + // aren't.) Note that this is about whether the exit branch is + // executed, not about whether it is taken. + if (!L->getLoopLatch() || + !DT->dominates(IncomingBB, L->getLoopLatch())) continue; // Get condition that leads to the exit path. @@ -744,8 +763,8 @@ bool IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { auto *ExitVal = dyn_cast<PHINode>(PN.getIncomingValue(IncomingValIdx)); - // Only deal with PHIs. - if (!ExitVal) + // Only deal with PHIs in the loop header. + if (!ExitVal || ExitVal->getParent() != L->getHeader()) continue; // If ExitVal is a PHI on the loop header, then we know its @@ -755,7 +774,7 @@ bool IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { assert(LoopPreheader && "Invalid loop"); int PreheaderIdx = ExitVal->getBasicBlockIndex(LoopPreheader); if (PreheaderIdx != -1) { - assert(ExitVal->getParent() == LoopHeader && + assert(ExitVal->getParent() == L->getHeader() && "ExitVal must be in loop header"); MadeAnyChanges = true; PN.setIncomingValue(IncomingValIdx, @@ -1022,24 +1041,13 @@ protected: } // end anonymous namespace -/// Perform a quick domtree based check for loop invariance assuming that V is -/// used within the loop. LoopInfo::isLoopInvariant() seems gratuitous for this -/// purpose. -static bool isLoopInvariant(Value *V, const Loop *L, const DominatorTree *DT) { - Instruction *Inst = dyn_cast<Instruction>(V); - if (!Inst) - return true; - - return DT->properlyDominates(Inst->getParent(), L->getHeader()); -} - Value *WidenIV::createExtendInst(Value *NarrowOper, Type *WideType, bool IsSigned, Instruction *Use) { // Set the debug location and conservative insertion point. 
IRBuilder<> Builder(Use); // Hoist the insertion point into loop preheaders as far as possible. for (const Loop *L = LI->getLoopFor(Use->getParent()); - L && L->getLoopPreheader() && isLoopInvariant(NarrowOper, L, DT); + L && L->getLoopPreheader() && L->isLoopInvariant(NarrowOper); L = L->getParentLoop()) Builder.SetInsertPoint(L->getLoopPreheader()->getTerminator()); @@ -1305,13 +1313,15 @@ WidenIV::WidenedRecTy WidenIV::getWideRecurrence(NarrowIVDefUse DU) { return {AddRec, ExtKind}; } -/// This IV user cannot be widen. Replace this use of the original narrow IV +/// This IV user cannot be widened. Replace this use of the original narrow IV /// with a truncation of the new wide IV to isolate and eliminate the narrow IV. static void truncateIVUse(NarrowIVDefUse DU, DominatorTree *DT, LoopInfo *LI) { + auto *InsertPt = getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI); + if (!InsertPt) + return; LLVM_DEBUG(dbgs() << "INDVARS: Truncate IV " << *DU.WideDef << " for user " << *DU.NarrowUse << "\n"); - IRBuilder<> Builder( - getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI)); + IRBuilder<> Builder(InsertPt); Value *Trunc = Builder.CreateTrunc(DU.WideDef, DU.NarrowDef->getType()); DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, Trunc); } @@ -1348,8 +1358,10 @@ bool WidenIV::widenLoopCompare(NarrowIVDefUse DU) { assert(CastWidth <= IVWidth && "Unexpected width while widening compare."); // Widen the compare instruction. - IRBuilder<> Builder( - getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI)); + auto *InsertPt = getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI); + if (!InsertPt) + return false; + IRBuilder<> Builder(InsertPt); DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, DU.WideDef); // Widen the other operand of the compare, if necessary. @@ -1977,41 +1989,10 @@ bool IndVarSimplify::simplifyAndExtend(Loop *L, // linearFunctionTestReplace and its kin. Rewrite the loop exit condition. //===----------------------------------------------------------------------===// -/// Return true if this loop's backedge taken count expression can be safely and -/// cheaply expanded into an instruction sequence that can be used by -/// linearFunctionTestReplace. -/// -/// TODO: This fails for pointer-type loop counters with greater than one byte -/// strides, consequently preventing LFTR from running. For the purpose of LFTR -/// we could skip this check in the case that the LFTR loop counter (chosen by -/// FindLoopCounter) is also pointer type. Instead, we could directly convert -/// the loop test to an inequality test by checking the target data's alignment -/// of element types (given that the initial pointer value originates from or is -/// used by ABI constrained operation, as opposed to inttoptr/ptrtoint). -/// However, we don't yet have a strong motivation for converting loop tests -/// into inequality tests. -static bool canExpandBackedgeTakenCount(Loop *L, ScalarEvolution *SE, - SCEVExpander &Rewriter) { - const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); - if (isa<SCEVCouldNotCompute>(BackedgeTakenCount) || - BackedgeTakenCount->isZero()) - return false; - - if (!L->getExitingBlock()) - return false; - - // Can't rewrite non-branch yet. - if (!isa<BranchInst>(L->getExitingBlock()->getTerminator())) - return false; - - if (Rewriter.isHighCostExpansion(BackedgeTakenCount, L)) - return false; - - return true; -} - -/// Return the loop header phi IFF IncV adds a loop invariant value to the phi. 
-static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L, DominatorTree *DT) { +/// Given a Value which is hoped to be part of an add recurrence in the given +/// loop, return the associated Phi node if so. Otherwise, return null. Note +/// that this is less general than SCEV's AddRec checking. +static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L) { Instruction *IncI = dyn_cast<Instruction>(IncV); if (!IncI) return nullptr; @@ -2031,7 +2012,7 @@ static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L, DominatorTree *DT) { PHINode *Phi = dyn_cast<PHINode>(IncI->getOperand(0)); if (Phi && Phi->getParent() == L->getHeader()) { - if (isLoopInvariant(IncI->getOperand(1), L, DT)) + if (L->isLoopInvariant(IncI->getOperand(1))) return Phi; return nullptr; } @@ -2041,32 +2022,40 @@ static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L, DominatorTree *DT) { // Allow add/sub to be commuted. Phi = dyn_cast<PHINode>(IncI->getOperand(1)); if (Phi && Phi->getParent() == L->getHeader()) { - if (isLoopInvariant(IncI->getOperand(0), L, DT)) + if (L->isLoopInvariant(IncI->getOperand(0))) return Phi; } return nullptr; } -/// Return the compare guarding the loop latch, or NULL for unrecognized tests. -static ICmpInst *getLoopTest(Loop *L) { - assert(L->getExitingBlock() && "expected loop exit"); - - BasicBlock *LatchBlock = L->getLoopLatch(); - // Don't bother with LFTR if the loop is not properly simplified. - if (!LatchBlock) - return nullptr; - - BranchInst *BI = dyn_cast<BranchInst>(L->getExitingBlock()->getTerminator()); - assert(BI && "expected exit branch"); +/// Whether the current loop exit test is based on this value. Currently this +/// is limited to a direct use in the loop condition. +static bool isLoopExitTestBasedOn(Value *V, BasicBlock *ExitingBB) { + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); + ICmpInst *ICmp = dyn_cast<ICmpInst>(BI->getCondition()); + // TODO: Allow non-icmp loop test. + if (!ICmp) + return false; - return dyn_cast<ICmpInst>(BI->getCondition()); + // TODO: Allow indirect use. + return ICmp->getOperand(0) == V || ICmp->getOperand(1) == V; } /// linearFunctionTestReplace policy. Return true unless we can show that the /// current exit test is already sufficiently canonical. -static bool needsLFTR(Loop *L, DominatorTree *DT) { +static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) { + assert(L->getLoopLatch() && "Must be in simplified form"); + + // Avoid converting a constant or loop invariant test back to a runtime + // test. This is critical for when SCEV's cached ExitCount is less precise + // than the current IR (such as after we've proven a particular exit is + // actually dead and thus the BE count never reaches our ExitCount.) + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); + if (L->isLoopInvariant(BI->getCondition())) + return false; + // Do LFTR to simplify the exit condition to an ICMP.
- ICmpInst *Cond = getLoopTest(L); + ICmpInst *Cond = dyn_cast<ICmpInst>(BI->getCondition()); if (!Cond) return true; @@ -2078,15 +2067,15 @@ static bool needsLFTR(Loop *L, DominatorTree *DT) { // Look for a loop invariant RHS Value *LHS = Cond->getOperand(0); Value *RHS = Cond->getOperand(1); - if (!isLoopInvariant(RHS, L, DT)) { - if (!isLoopInvariant(LHS, L, DT)) + if (!L->isLoopInvariant(RHS)) { + if (!L->isLoopInvariant(LHS)) return true; std::swap(LHS, RHS); } // Look for a simple IV counter LHS PHINode *Phi = dyn_cast<PHINode>(LHS); if (!Phi) - Phi = getLoopPhiForCounter(LHS, L, DT); + Phi = getLoopPhiForCounter(LHS, L); if (!Phi) return true; @@ -2098,7 +2087,49 @@ static bool needsLFTR(Loop *L, DominatorTree *DT) { // Do LFTR if the exit condition's IV is *not* a simple counter. Value *IncV = Phi->getIncomingValue(Idx); - return Phi != getLoopPhiForCounter(IncV, L, DT); + return Phi != getLoopPhiForCounter(IncV, L); +} + +/// Return true if undefined behavior would provably be executed on the path to +/// OnPathTo if Root produced a poison result. Note that this doesn't say +/// anything about whether OnPathTo is actually executed or whether Root is +/// actually poison. This can be used to assess whether a new use of Root can +/// be added at a location which is control equivalent with OnPathTo (such as +/// immediately before it) without introducing UB which didn't previously +/// exist. Note that a false result conveys no information. +static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root, + Instruction *OnPathTo, + DominatorTree *DT) { + // Basic approach is to assume Root is poison, propagate poison forward + // through all users we can easily track, and then check whether any of those + // users are provably UB and must execute before our exiting block might + // exit. + + // The set of all recursive users we've visited (which are assumed to all be + // poison because of said visit) + SmallSet<const Value *, 16> KnownPoison; + SmallVector<const Instruction*, 16> Worklist; + Worklist.push_back(Root); + while (!Worklist.empty()) { + const Instruction *I = Worklist.pop_back_val(); + + // If we know this must trigger UB on a path leading to our target. + if (mustTriggerUB(I, KnownPoison) && DT->dominates(I, OnPathTo)) + return true; + + // If we can't analyze propagation through this instruction, just skip it + // and transitive users. Safe as false is a conservative result. + if (!propagatesFullPoison(I) && I != Root) + continue; + + if (KnownPoison.insert(I).second) + for (const User *User : I->users()) + Worklist.push_back(cast<Instruction>(User)); + } + + // Might be non-UB, or might have a path we couldn't prove must execute on + // the way to the exiting bb. + return false; } /// Recursive helper for hasConcreteDef(). Unfortunately, this currently boils @@ -2157,46 +2188,62 @@ static bool AlmostDeadIV(PHINode *Phi, BasicBlock *LatchBlock, Value *Cond) { return true; } -/// Find an affine IV in canonical form. +/// Return true if the given phi is a "counter" in L. A counter is an +/// add recurrence (of integer or pointer type) with an arbitrary start, and a +/// step of 1. Note that L must have exactly one latch.
+static bool isLoopCounter(PHINode* Phi, Loop *L, + ScalarEvolution *SE) { + assert(Phi->getParent() == L->getHeader()); + assert(L->getLoopLatch()); + + if (!SE->isSCEVable(Phi->getType())) + return false; + + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Phi)); + if (!AR || AR->getLoop() != L || !AR->isAffine()) + return false; + + const SCEV *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)); + if (!Step || !Step->isOne()) + return false; + + int LatchIdx = Phi->getBasicBlockIndex(L->getLoopLatch()); + Value *IncV = Phi->getIncomingValue(LatchIdx); + return (getLoopPhiForCounter(IncV, L) == Phi); +} + +/// Search the loop header for a loop counter (an add rec w/step of one) +/// suitable for use by LFTR. If multiple counters are available, select the +/// "best" one based on profitability heuristics. /// /// BECount may be an i8* pointer type. The pointer difference is already /// valid count without scaling the address stride, so it remains a pointer /// expression as far as SCEV is concerned. -/// -/// Currently only valid for LFTR. See the comments on hasConcreteDef below. -/// -/// FIXME: Accept -1 stride and set IVLimit = IVInit - BECount -/// -/// FIXME: Accept non-unit stride as long as SCEV can reduce BECount * Stride. -/// This is difficult in general for SCEV because of potential overflow. But we -/// could at least handle constant BECounts. -static PHINode *FindLoopCounter(Loop *L, const SCEV *BECount, +static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB, + const SCEV *BECount, ScalarEvolution *SE, DominatorTree *DT) { uint64_t BCWidth = SE->getTypeSizeInBits(BECount->getType()); - Value *Cond = - cast<BranchInst>(L->getExitingBlock()->getTerminator())->getCondition(); + Value *Cond = cast<BranchInst>(ExitingBB->getTerminator())->getCondition(); // Loop over all of the PHI nodes, looking for a simple counter. PHINode *BestPhi = nullptr; const SCEV *BestInit = nullptr; BasicBlock *LatchBlock = L->getLoopLatch(); - assert(LatchBlock && "needsLFTR should guarantee a loop latch"); + assert(LatchBlock && "Must be in simplified form"); const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { PHINode *Phi = cast<PHINode>(I); - if (!SE->isSCEVable(Phi->getType())) + if (!isLoopCounter(Phi, L, SE)) continue; // Avoid comparing an integer IV against a pointer Limit. if (BECount->getType()->isPointerTy() && !Phi->getType()->isPointerTy()) continue; - const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Phi)); - if (!AR || AR->getLoop() != L || !AR->isAffine()) - continue; - + const auto *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Phi)); + // AR may be a pointer type, while BECount is an integer type. // AR may be wider than BECount. With eq/ne tests overflow is immaterial. // AR may not be a narrower type, or we may never exit. @@ -2204,28 +2251,30 @@ static PHINode *FindLoopCounter(Loop *L, const SCEV *BECount, if (PhiWidth < BCWidth || !DL.isLegalInteger(PhiWidth)) continue; - const SCEV *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)); - if (!Step || !Step->isOne()) - continue; - - int LatchIdx = Phi->getBasicBlockIndex(LatchBlock); - Value *IncV = Phi->getIncomingValue(LatchIdx); - if (getLoopPhiForCounter(IncV, L, DT) != Phi) - continue; - // Avoid reusing a potentially undef value to compute other values that may // have originally had a concrete definition.
if (!hasConcreteDef(Phi)) { // We explicitly allow unknown phis as long as they are already used by - // the loop test. In this case we assume that performing LFTR could not - // increase the number of undef users. - if (ICmpInst *Cond = getLoopTest(L)) { - if (Phi != getLoopPhiForCounter(Cond->getOperand(0), L, DT) && - Phi != getLoopPhiForCounter(Cond->getOperand(1), L, DT)) { - continue; - } - } + // the loop exit test. This is legal since performing LFTR could not + // increase the number of undef users. + Value *IncPhi = Phi->getIncomingValueForBlock(LatchBlock); + if (!isLoopExitTestBasedOn(Phi, ExitingBB) && + !isLoopExitTestBasedOn(IncPhi, ExitingBB)) + continue; } + + // Avoid introducing undefined behavior due to poison which didn't exist in + // the original program. (Annoyingly, the rules for poison and undef + // propagation are distinct, so this does NOT cover the undef case above.) + // We have to ensure that we don't introduce UB by introducing a use on an + // iteration where said IV produces poison. Our strategy here differs for + // pointers and integer IVs. For integers, we strip and reinfer as needed, + // see code in linearFunctionTestReplace. For pointers, we restrict + // transforms as there is no good way to reinfer inbounds once lost. + if (!Phi->getType()->isIntegerTy() && + !mustExecuteUBIfPoisonOnPathTo(Phi, ExitingBB->getTerminator(), DT)) + continue; + const SCEV *Init = AR->getStart(); if (BestPhi && !AlmostDeadIV(BestPhi, LatchBlock, Cond)) { @@ -2251,47 +2300,49 @@ static PHINode *FindLoopCounter(Loop *L, const SCEV *BECount, return BestPhi; } -/// Help linearFunctionTestReplace by generating a value that holds the RHS of -/// the new loop test. -static Value *genLoopLimit(PHINode *IndVar, const SCEV *IVCount, Loop *L, +/// Insert an IR expression which computes the value held by the IV IndVar +/// (which must be an loop counter w/unit stride) after the backedge of loop L +/// is taken ExitCount times. +static Value *genLoopLimit(PHINode *IndVar, BasicBlock *ExitingBB, + const SCEV *ExitCount, bool UsePostInc, Loop *L, SCEVExpander &Rewriter, ScalarEvolution *SE) { - const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(IndVar)); - assert(AR && AR->getLoop() == L && AR->isAffine() && "bad loop counter"); + assert(isLoopCounter(IndVar, L, SE)); + const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IndVar)); const SCEV *IVInit = AR->getStart(); - // IVInit may be a pointer while IVCount is an integer when FindLoopCounter - // finds a valid pointer IV. Sign extend BECount in order to materialize a + // IVInit may be a pointer while ExitCount is an integer when FindLoopCounter + // finds a valid pointer IV. Sign extend ExitCount in order to materialize a // GEP. Avoid running SCEVExpander on a new pointer value, instead reusing // the existing GEPs whenever possible. - if (IndVar->getType()->isPointerTy() && !IVCount->getType()->isPointerTy()) { + if (IndVar->getType()->isPointerTy() && + !ExitCount->getType()->isPointerTy()) { // IVOffset will be the new GEP offset that is interpreted by GEP as a - // signed value. IVCount on the other hand represents the loop trip count, + // signed value. ExitCount on the other hand represents the loop trip count, // which is an unsigned value. FindLoopCounter only allows induction // variables that have a positive unit stride of one. This means we don't // have to handle the case of negative offsets (yet) and just need to zero - // extend IVCount. + // extend ExitCount. 
Type *OfsTy = SE->getEffectiveSCEVType(IVInit->getType()); - const SCEV *IVOffset = SE->getTruncateOrZeroExtend(IVCount, OfsTy); + const SCEV *IVOffset = SE->getTruncateOrZeroExtend(ExitCount, OfsTy); + if (UsePostInc) + IVOffset = SE->getAddExpr(IVOffset, SE->getOne(OfsTy)); // Expand the code for the iteration count. assert(SE->isLoopInvariant(IVOffset, L) && "Computed iteration count is not loop invariant!"); - BranchInst *BI = cast<BranchInst>(L->getExitingBlock()->getTerminator()); - Value *GEPOffset = Rewriter.expandCodeFor(IVOffset, OfsTy, BI); - Value *GEPBase = IndVar->getIncomingValueForBlock(L->getLoopPreheader()); - assert(AR->getStart() == SE->getSCEV(GEPBase) && "bad loop counter"); // We could handle pointer IVs other than i8*, but we need to compensate for - // gep index scaling. See canExpandBackedgeTakenCount comments. + // gep index scaling. assert(SE->getSizeOfExpr(IntegerType::getInt64Ty(IndVar->getContext()), - cast<PointerType>(GEPBase->getType()) + cast<PointerType>(IndVar->getType()) ->getElementType())->isOne() && "unit stride pointer IV must be i8*"); - IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); - return Builder.CreateGEP(nullptr, GEPBase, GEPOffset, "lftr.limit"); + const SCEV *IVLimit = SE->getAddExpr(IVInit, IVOffset); + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); + return Rewriter.expandCodeFor(IVLimit, IndVar->getType(), BI); } else { - // In any other case, convert both IVInit and IVCount to integers before + // In any other case, convert both IVInit and ExitCount to integers before // comparing. This may result in SCEV expansion of pointers, but in practice // SCEV will fold the pointer arithmetic away as such: // BECount = (IVEnd - IVInit - 1) => IVLimit = IVInit (postinc). @@ -2299,35 +2350,40 @@ static Value *genLoopLimit(PHINode *IndVar, const SCEV *IVCount, Loop *L, // Valid Cases: (1) both integers is most common; (2) both may be pointers // for simple memset-style loops. // - // IVInit integer and IVCount pointer would only occur if a canonical IV + // IVInit integer and ExitCount pointer would only occur if a canonical IV // were generated on top of case #2, which is not expected. - const SCEV *IVLimit = nullptr; - // For unit stride, IVCount = Start + BECount with 2's complement overflow. - // For non-zero Start, compute IVCount here. - if (AR->getStart()->isZero()) - IVLimit = IVCount; - else { - assert(AR->getStepRecurrence(*SE)->isOne() && "only handles unit stride"); - const SCEV *IVInit = AR->getStart(); + assert(AR->getStepRecurrence(*SE)->isOne() && "only handles unit stride"); + // For unit stride, IVCount = Start + ExitCount with 2's complement + // overflow. + + // For integer IVs, truncate the IV before computing IVInit + BECount, + // unless we know apriori that the limit must be a constant when evaluated + // in the bitwidth of the IV. We prefer (potentially) keeping a truncate + // of the IV in the loop over a (potentially) expensive expansion of the + // widened exit count add(zext(add)) expression. + if (SE->getTypeSizeInBits(IVInit->getType()) + > SE->getTypeSizeInBits(ExitCount->getType())) { + if (isa<SCEVConstant>(IVInit) && isa<SCEVConstant>(ExitCount)) + ExitCount = SE->getZeroExtendExpr(ExitCount, IVInit->getType()); + else + IVInit = SE->getTruncateExpr(IVInit, ExitCount->getType()); + } - // For integer IVs, truncate the IV before computing IVInit + BECount. 
- if (SE->getTypeSizeInBits(IVInit->getType()) - > SE->getTypeSizeInBits(IVCount->getType())) - IVInit = SE->getTruncateExpr(IVInit, IVCount->getType()); + const SCEV *IVLimit = SE->getAddExpr(IVInit, ExitCount); + + if (UsePostInc) + IVLimit = SE->getAddExpr(IVLimit, SE->getOne(IVLimit->getType())); - IVLimit = SE->getAddExpr(IVInit, IVCount); - } // Expand the code for the iteration count. - BranchInst *BI = cast<BranchInst>(L->getExitingBlock()->getTerminator()); - IRBuilder<> Builder(BI); assert(SE->isLoopInvariant(IVLimit, L) && "Computed iteration count is not loop invariant!"); // Ensure that we generate the same type as IndVar, or a smaller integer // type. In the presence of null pointer values, we have an integer type // SCEV expression (IVInit) for a pointer type IV value (IndVar). - Type *LimitTy = IVCount->getType()->isPointerTy() ? - IndVar->getType() : IVCount->getType(); + Type *LimitTy = ExitCount->getType()->isPointerTy() ? + IndVar->getType() : ExitCount->getType(); + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); return Rewriter.expandCodeFor(IVLimit, LimitTy, BI); } } @@ -2338,51 +2394,70 @@ static Value *genLoopLimit(PHINode *IndVar, const SCEV *IVCount, Loop *L, /// determine a loop-invariant trip count of the loop, which is actually a much /// broader range than just linear tests. bool IndVarSimplify:: -linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, +linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB, + const SCEV *ExitCount, PHINode *IndVar, SCEVExpander &Rewriter) { - assert(canExpandBackedgeTakenCount(L, SE, Rewriter) && "precondition"); + assert(L->getLoopLatch() && "Loop no longer in simplified form?"); + assert(isLoopCounter(IndVar, L, SE)); + Instruction * const IncVar = + cast<Instruction>(IndVar->getIncomingValueForBlock(L->getLoopLatch())); - // Initialize CmpIndVar and IVCount to their preincremented values. + // Initialize CmpIndVar to the preincremented IV. Value *CmpIndVar = IndVar; - const SCEV *IVCount = BackedgeTakenCount; - - assert(L->getLoopLatch() && "Loop no longer in simplified form?"); + bool UsePostInc = false; // If the exiting block is the same as the backedge block, we prefer to // compare against the post-incremented value, otherwise we must compare // against the preincremented value. - if (L->getExitingBlock() == L->getLoopLatch()) { - // Add one to the "backedge-taken" count to get the trip count. - // This addition may overflow, which is valid as long as the comparison is - // truncated to BackedgeTakenCount->getType(). - IVCount = SE->getAddExpr(BackedgeTakenCount, - SE->getOne(BackedgeTakenCount->getType())); - // The BackedgeTaken expression contains the number of times that the - // backedge branches to the loop header. This is one less than the - // number of times the loop executes, so use the incremented indvar. - CmpIndVar = IndVar->getIncomingValueForBlock(L->getExitingBlock()); + if (ExitingBB == L->getLoopLatch()) { + // For pointer IVs, we chose to not strip inbounds which requires us not + // to add a potentially UB introducing use. We need to either a) show + // the loop test we're modifying is already in post-inc form, or b) show + // that adding a use must not introduce UB. 
+ bool SafeToPostInc = + IndVar->getType()->isIntegerTy() || + isLoopExitTestBasedOn(IncVar, ExitingBB) || + mustExecuteUBIfPoisonOnPathTo(IncVar, ExitingBB->getTerminator(), DT); + if (SafeToPostInc) { + UsePostInc = true; + CmpIndVar = IncVar; + } } - Value *ExitCnt = genLoopLimit(IndVar, IVCount, L, Rewriter, SE); + // It may be necessary to drop nowrap flags on the incrementing instruction + // if either LFTR moves from a pre-inc check to a post-inc check (in which + // case the increment might have previously been poison on the last iteration + // only) or if LFTR switches to a different IV that was previously dynamically + // dead (and as such may be arbitrarily poison). We remove any nowrap flags + // that SCEV didn't infer for the post-inc addrec (even if we use a pre-inc + // check), because the pre-inc addrec flags may be adopted from the original + // instruction, while SCEV has to explicitly prove the post-inc nowrap flags. + // TODO: This handling is inaccurate for one case: If we switch to a + // dynamically dead IV that wraps on the first loop iteration only, which is + // not covered by the post-inc addrec. (If the new IV was not dynamically + // dead, it could not be poison on the first iteration in the first place.) + if (auto *BO = dyn_cast<BinaryOperator>(IncVar)) { + const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IncVar)); + if (BO->hasNoUnsignedWrap()) + BO->setHasNoUnsignedWrap(AR->hasNoUnsignedWrap()); + if (BO->hasNoSignedWrap()) + BO->setHasNoSignedWrap(AR->hasNoSignedWrap()); + } + + Value *ExitCnt = genLoopLimit( + IndVar, ExitingBB, ExitCount, UsePostInc, L, Rewriter, SE); assert(ExitCnt->getType()->isPointerTy() == IndVar->getType()->isPointerTy() && "genLoopLimit missed a cast"); // Insert a new icmp_ne or icmp_eq instruction before the branch. - BranchInst *BI = cast<BranchInst>(L->getExitingBlock()->getTerminator()); + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); ICmpInst::Predicate P; if (L->contains(BI->getSuccessor(0))) P = ICmpInst::ICMP_NE; else P = ICmpInst::ICMP_EQ; - LLVM_DEBUG(dbgs() << "INDVARS: Rewriting loop exit condition to:\n" - << " LHS:" << *CmpIndVar << '\n' - << " op:\t" << (P == ICmpInst::ICMP_NE ? "!=" : "==") - << "\n" - << " RHS:\t" << *ExitCnt << "\n" - << " IVCount:\t" << *IVCount << "\n"); - IRBuilder<> Builder(BI); // The new loop exit condition should reuse the debug location of the @@ -2390,67 +2465,58 @@ linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, if (auto *Cond = dyn_cast<Instruction>(BI->getCondition())) Builder.SetCurrentDebugLocation(Cond->getDebugLoc()); - // LFTR can ignore IV overflow and truncate to the width of - // BECount. This avoids materializing the add(zext(add)) expression. + // For integer IVs, if we evaluated the limit in the narrower bitwidth to + // avoid the expensive expansion of the limit expression in the wider type, + // emit a truncate to narrow the IV to the ExitCount type. This is safe + // since we know (from the exit count bitwidth), that we can't self-wrap in + // the narrower type. unsigned CmpIndVarSize = SE->getTypeSizeInBits(CmpIndVar->getType()); unsigned ExitCntSize = SE->getTypeSizeInBits(ExitCnt->getType()); if (CmpIndVarSize > ExitCntSize) { - const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IndVar)); - const SCEV *ARStart = AR->getStart(); - const SCEV *ARStep = AR->getStepRecurrence(*SE); - // For constant IVCount, avoid truncation. 
- if (isa<SCEVConstant>(ARStart) && isa<SCEVConstant>(IVCount)) { - const APInt &Start = cast<SCEVConstant>(ARStart)->getAPInt(); - APInt Count = cast<SCEVConstant>(IVCount)->getAPInt(); - // Note that the post-inc value of BackedgeTakenCount may have overflowed - // above such that IVCount is now zero. - if (IVCount != BackedgeTakenCount && Count == 0) { - Count = APInt::getMaxValue(Count.getBitWidth()).zext(CmpIndVarSize); - ++Count; - } - else - Count = Count.zext(CmpIndVarSize); - APInt NewLimit; - if (cast<SCEVConstant>(ARStep)->getValue()->isNegative()) - NewLimit = Start - Count; - else - NewLimit = Start + Count; - ExitCnt = ConstantInt::get(CmpIndVar->getType(), NewLimit); - - LLVM_DEBUG(dbgs() << " Widen RHS:\t" << *ExitCnt << "\n"); + assert(!CmpIndVar->getType()->isPointerTy() && + !ExitCnt->getType()->isPointerTy()); + + // Before resorting to actually inserting the truncate, use the same + // reasoning as from SimplifyIndvar::eliminateTrunc to see if we can extend + // the other side of the comparison instead. We still evaluate the limit + // in the narrower bitwidth, we just prefer a zext/sext outside the loop to + // a truncate within in. + bool Extended = false; + const SCEV *IV = SE->getSCEV(CmpIndVar); + const SCEV *TruncatedIV = SE->getTruncateExpr(SE->getSCEV(CmpIndVar), + ExitCnt->getType()); + const SCEV *ZExtTrunc = + SE->getZeroExtendExpr(TruncatedIV, CmpIndVar->getType()); + + if (ZExtTrunc == IV) { + Extended = true; + ExitCnt = Builder.CreateZExt(ExitCnt, IndVar->getType(), + "wide.trip.count"); } else { - // We try to extend trip count first. If that doesn't work we truncate IV. - // Zext(trunc(IV)) == IV implies equivalence of the following two: - // Trunc(IV) == ExitCnt and IV == zext(ExitCnt). Similarly for sext. If - // one of the two holds, extend the trip count, otherwise we truncate IV. - bool Extended = false; - const SCEV *IV = SE->getSCEV(CmpIndVar); - const SCEV *ZExtTrunc = - SE->getZeroExtendExpr(SE->getTruncateExpr(SE->getSCEV(CmpIndVar), - ExitCnt->getType()), - CmpIndVar->getType()); - - if (ZExtTrunc == IV) { + const SCEV *SExtTrunc = + SE->getSignExtendExpr(TruncatedIV, CmpIndVar->getType()); + if (SExtTrunc == IV) { Extended = true; - ExitCnt = Builder.CreateZExt(ExitCnt, IndVar->getType(), + ExitCnt = Builder.CreateSExt(ExitCnt, IndVar->getType(), "wide.trip.count"); - } else { - const SCEV *SExtTrunc = - SE->getSignExtendExpr(SE->getTruncateExpr(SE->getSCEV(CmpIndVar), - ExitCnt->getType()), - CmpIndVar->getType()); - if (SExtTrunc == IV) { - Extended = true; - ExitCnt = Builder.CreateSExt(ExitCnt, IndVar->getType(), - "wide.trip.count"); - } } - - if (!Extended) - CmpIndVar = Builder.CreateTrunc(CmpIndVar, ExitCnt->getType(), - "lftr.wideiv"); } + + if (Extended) { + bool Discard; + L->makeLoopInvariant(ExitCnt, Discard); + } else + CmpIndVar = Builder.CreateTrunc(CmpIndVar, ExitCnt->getType(), + "lftr.wideiv"); } + LLVM_DEBUG(dbgs() << "INDVARS: Rewriting loop exit condition to:\n" + << " LHS:" << *CmpIndVar << '\n' + << " op:\t" << (P == ICmpInst::ICMP_NE ? 
"!=" : "==") + << "\n" + << " RHS:\t" << *ExitCnt << "\n" + << "ExitCount:\t" << *ExitCount << "\n" + << " was: " << *BI->getCondition() << "\n"); + Value *Cond = Builder.CreateICmp(P, CmpIndVar, ExitCnt, "exitcond"); Value *OrigCond = BI->getCondition(); // It's tempting to use replaceAllUsesWith here to fully replace the old @@ -2558,6 +2624,111 @@ bool IndVarSimplify::sinkUnusedInvariants(Loop *L) { return MadeAnyChanges; } +bool IndVarSimplify::optimizeLoopExits(Loop *L) { + SmallVector<BasicBlock*, 16> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + + // Form an expression for the maximum exit count possible for this loop. We + // merge the max and exact information to approximate a version of + // getMaxBackedgeTakenInfo which isn't restricted to just constants. + // TODO: factor this out as a version of getMaxBackedgeTakenCount which + // isn't guaranteed to return a constant. + SmallVector<const SCEV*, 4> ExitCounts; + const SCEV *MaxConstEC = SE->getMaxBackedgeTakenCount(L); + if (!isa<SCEVCouldNotCompute>(MaxConstEC)) + ExitCounts.push_back(MaxConstEC); + for (BasicBlock *ExitingBB : ExitingBlocks) { + const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); + if (!isa<SCEVCouldNotCompute>(ExitCount)) { + assert(DT->dominates(ExitingBB, L->getLoopLatch()) && + "We should only have known counts for exiting blocks that " + "dominate latch!"); + ExitCounts.push_back(ExitCount); + } + } + if (ExitCounts.empty()) + return false; + const SCEV *MaxExitCount = SE->getUMinFromMismatchedTypes(ExitCounts); + + bool Changed = false; + for (BasicBlock *ExitingBB : ExitingBlocks) { + // If our exitting block exits multiple loops, we can only rewrite the + // innermost one. Otherwise, we're changing how many times the innermost + // loop runs before it exits. + if (LI->getLoopFor(ExitingBB) != L) + continue; + + // Can't rewrite non-branch yet. + BranchInst *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); + if (!BI) + continue; + + // If already constant, nothing to do. + if (isa<Constant>(BI->getCondition())) + continue; + + const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); + if (isa<SCEVCouldNotCompute>(ExitCount)) + continue; + + // If we know we'd exit on the first iteration, rewrite the exit to + // reflect this. This does not imply the loop must exit through this + // exit; there may be an earlier one taken on the first iteration. + // TODO: Given we know the backedge can't be taken, we should go ahead + // and break it. Or at least, kill all the header phis and simplify. + if (ExitCount->isZero()) { + bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); + auto *OldCond = BI->getCondition(); + auto *NewCond = ExitIfTrue ? ConstantInt::getTrue(OldCond->getType()) : + ConstantInt::getFalse(OldCond->getType()); + BI->setCondition(NewCond); + if (OldCond->use_empty()) + DeadInsts.push_back(OldCond); + Changed = true; + continue; + } + + // If we end up with a pointer exit count, bail. + if (!ExitCount->getType()->isIntegerTy() || + !MaxExitCount->getType()->isIntegerTy()) + return false; + + Type *WiderType = + SE->getWiderType(MaxExitCount->getType(), ExitCount->getType()); + ExitCount = SE->getNoopOrZeroExtend(ExitCount, WiderType); + MaxExitCount = SE->getNoopOrZeroExtend(MaxExitCount, WiderType); + assert(MaxExitCount->getType() == ExitCount->getType()); + + // Can we prove that some other exit must be taken strictly before this + // one? 
TODO: handle cases where ule is known, and equality is covered + // by a dominating exit + if (SE->isLoopEntryGuardedByCond(L, CmpInst::ICMP_ULT, + MaxExitCount, ExitCount)) { + bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); + auto *OldCond = BI->getCondition(); + auto *NewCond = ExitIfTrue ? ConstantInt::getFalse(OldCond->getType()) : + ConstantInt::getTrue(OldCond->getType()); + BI->setCondition(NewCond); + if (OldCond->use_empty()) + DeadInsts.push_back(OldCond); + Changed = true; + continue; + } + + // TODO: If we can prove that the exiting iteration is equal to the exit + // count for this exit and that no previous exit oppurtunities exist within + // the loop, then we can discharge all other exits. (May fall out of + // previous TODO.) + + // TODO: If we can't prove any relation between our exit count and the + // loops exit count, but taking this exit doesn't require actually running + // the loop (i.e. no side effects, no computed values used in exit), then + // we can replace the exit test with a loop invariant test which exits on + // the first iteration. + } + return Changed; +} + //===----------------------------------------------------------------------===// // IndVarSimplify driver. Manage several subpasses of IV simplification. //===----------------------------------------------------------------------===// @@ -2614,23 +2785,60 @@ bool IndVarSimplify::run(Loop *L) { // Eliminate redundant IV cycles. NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts); + Changed |= optimizeLoopExits(L); + // If we have a trip count expression, rewrite the loop's exit condition - // using it. We can currently only handle loops with a single exit. - if (!DisableLFTR && canExpandBackedgeTakenCount(L, SE, Rewriter) && - needsLFTR(L, DT)) { - PHINode *IndVar = FindLoopCounter(L, BackedgeTakenCount, SE, DT); - if (IndVar) { + // using it. + if (!DisableLFTR) { + SmallVector<BasicBlock*, 16> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + for (BasicBlock *ExitingBB : ExitingBlocks) { + // Can't rewrite non-branch yet. + if (!isa<BranchInst>(ExitingBB->getTerminator())) + continue; + + // If our exitting block exits multiple loops, we can only rewrite the + // innermost one. Otherwise, we're changing how many times the innermost + // loop runs before it exits. + if (LI->getLoopFor(ExitingBB) != L) + continue; + + if (!needsLFTR(L, ExitingBB)) + continue; + + const SCEV *ExitCount = SE->getExitCount(L, ExitingBB); + if (isa<SCEVCouldNotCompute>(ExitCount)) + continue; + + // This was handled above, but as we form SCEVs, we can sometimes refine + // existing ones; this allows exit counts to be folded to zero which + // weren't when optimizeLoopExits saw them. Arguably, we should iterate + // until stable to handle cases like this better. + if (ExitCount->isZero()) + continue; + + PHINode *IndVar = FindLoopCounter(L, ExitingBB, ExitCount, SE, DT); + if (!IndVar) + continue; + + // Avoid high cost expansions. Note: This heuristic is questionable in + // that our definition of "high cost" is not exactly principled. + if (Rewriter.isHighCostExpansion(ExitCount, L)) + continue; + // Check preconditions for proper SCEVExpander operation. SCEV does not - // express SCEVExpander's dependencies, such as LoopSimplify. Instead any - // pass that uses the SCEVExpander must do it. This does not work well for - // loop passes because SCEVExpander makes assumptions about all loops, - // while LoopPassManager only forces the current loop to be simplified. 
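The optimizeLoopExits function added above merges every computable per-exit count into one unsigned minimum and then folds exit tests that are provably always or never taken. Below is a minimal standalone sketch of that classification in plain C++ rather than SCEV; the identifiers are mine, and the real pass additionally checks loop nesting, latch dominance, and type widening that this model omits.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

enum class Fold { None, AlwaysExit, NeverExit };

// ExitCount = number of complete iterations before this exit would trigger.
static Fold classifyExit(uint64_t ExitCount, uint64_t MaxExitCount) {
  if (ExitCount == 0)
    return Fold::AlwaysExit; // exit is taken on the first iteration
  if (MaxExitCount < ExitCount)
    return Fold::NeverExit;  // another exit always terminates the loop first
  return Fold::None;
}

int main() {
  std::vector<uint64_t> ExitCounts = {7, 0, 100};
  uint64_t MaxExitCount =
      *std::min_element(ExitCounts.begin(), ExitCounts.end());
  for (uint64_t EC : ExitCounts)
    std::cout << EC << " -> "
              << static_cast<int>(classifyExit(EC, MaxExitCount)) << '\n';
  return 0;
}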
+ // express SCEVExpander's dependencies, such as LoopSimplify. Instead + // any pass that uses the SCEVExpander must do it. This does not work + // well for loop passes because SCEVExpander makes assumptions about + // all loops, while LoopPassManager only forces the current loop to be + // simplified. // // FIXME: SCEV expansion has no way to bail out, so the caller must // explicitly check any assumptions made by SCEV. Brittle. - const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(BackedgeTakenCount); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ExitCount); if (!AR || AR->getLoop()->getLoopPreheader()) - Changed |= linearFunctionTestReplace(L, BackedgeTakenCount, IndVar, + Changed |= linearFunctionTestReplace(L, ExitingBB, + ExitCount, IndVar, Rewriter); } } diff --git a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 1c701bbee185..997d68838152 100644 --- a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -1,9 +1,8 @@ //===- InductiveRangeCheckElimination.cpp - -------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -116,6 +115,11 @@ static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks", static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch", cl::Hidden, cl::init(true)); +static cl::opt<bool> AllowNarrowLatchCondition( + "irce-allow-narrow-latch", cl::Hidden, cl::init(true), + cl::desc("If set to true, IRCE may eliminate wide range checks in loops " + "with narrow latch condition.")); + static const char *ClonedLoopTag = "irce.loop.clone"; #define DEBUG_TYPE "irce" @@ -532,12 +536,6 @@ class LoopConstrainer { Optional<const SCEV *> HighLimit; }; - // A utility function that does a `replaceUsesOfWith' on the incoming block - // set of a `PHINode' -- replaces instances of `Block' in the `PHINode's - // incoming block list with `ReplaceBy'. - static void replacePHIBlock(PHINode *PN, BasicBlock *Block, - BasicBlock *ReplaceBy); - // Compute a safe set of limits for the main loop to run in -- effectively the // intersection of `Range' and the iteration space of the original loop. // Return None if unable to compute the set of subranges. @@ -639,13 +637,6 @@ public: } // end anonymous namespace -void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, - BasicBlock *ReplaceBy) { - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (PN->getIncomingBlock(i) == Block) - PN->setIncomingBlock(i, ReplaceBy); -} - /// Given a loop with an deccreasing induction variable, is it possible to /// safely calculate the bounds of a new loop using the given Predicate. 
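Returning to the linearFunctionTestReplace hunk above: the rewrite now prefers widening the exit count over narrowing the comparison, because zext(trunc(IV)) == IV (or the sext equivalent) means the compare can stay in the wide type with one extension outside the loop instead of a truncate on every iteration. The plain-C++ snippet below only illustrates the bit-level identity being relied on; the pass proves it for the IV's entire range through ScalarEvolution, not for a single runtime value.

#include <cstdint>
#include <iostream>

static bool zextOfTruncIsIdentity(uint64_t IV, unsigned NarrowBits) {
  uint64_t Narrow = NarrowBits >= 64 ? IV : IV & ((1ull << NarrowBits) - 1);
  return Narrow == IV; // true iff the high bits were already zero
}

int main() {
  // A 64-bit IV known to stay below 2^32: compare "IV == zext(limit)".
  std::cout << zextOfTruncIsIdentity(12345, 32) << '\n';      // 1 -> widen limit
  // An IV that actually uses the high bits: emit "trunc(IV) == limit" instead.
  std::cout << zextOfTruncIsIdentity(1ull << 40, 32) << '\n'; // 0 -> truncate IV
  return 0;
}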
static bool isSafeDecreasingBound(const SCEV *Start, @@ -868,7 +859,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, assert(!StepCI->isZero() && "Zero step?"); bool IsIncreasing = !StepCI->isNegative(); - bool IsSignedPredicate = ICmpInst::isSigned(Pred); + bool IsSignedPredicate; const SCEV *StartNext = IndVarBase->getStart(); const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); @@ -1045,11 +1036,23 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, return Result; } +/// If the type of \p S matches with \p Ty, return \p S. Otherwise, return +/// signed or unsigned extension of \p S to type \p Ty. +static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE, + bool Signed) { + return Signed ? SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty); +} + Optional<LoopConstrainer::SubRanges> LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); - if (Range.getType() != Ty) + auto *RTy = cast<IntegerType>(Range.getType()); + + // We only support wide range checks and narrow latches. + if (!AllowNarrowLatchCondition && RTy != Ty) + return None; + if (RTy->getBitWidth() < Ty->getBitWidth()) return None; LoopConstrainer::SubRanges Result; @@ -1057,8 +1060,10 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { // I think we can be more aggressive here and make this nuw / nsw if the // addition that feeds into the icmp for the latch's terminating branch is nuw // / nsw. In any case, a wrapping 2's complement addition is safe. - const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart); - const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt); + const SCEV *Start = NoopOrExtend(SE.getSCEV(MainLoopStructure.IndVarStart), + RTy, SE, IsSignedPredicate); + const SCEV *End = NoopOrExtend(SE.getSCEV(MainLoopStructure.LoopExitAt), RTy, + SE, IsSignedPredicate); bool Increasing = MainLoopStructure.IndVarIncreasing; @@ -1068,7 +1073,7 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr; - const SCEV *One = SE.getOne(Ty); + const SCEV *One = SE.getOne(RTy); if (Increasing) { Smallest = Start; Greatest = End; @@ -1257,6 +1262,13 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( bool IsSignedPredicate = LS.IsSignedPredicate; IRBuilder<> B(PreheaderJump); + auto *RangeTy = Range.getBegin()->getType(); + auto NoopOrExt = [&](Value *V) { + if (V->getType() == RangeTy) + return V; + return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName()) + : B.CreateZExt(V, RangeTy, "wide." + V->getName()); + }; // EnterLoopCond - is it okay to start executing this `LS'? Value *EnterLoopCond = nullptr; @@ -1264,15 +1276,16 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( Increasing ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT) : (IsSignedPredicate ? 
ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); - EnterLoopCond = B.CreateICmp(Pred, LS.IndVarStart, ExitSubloopAt); + Value *IndVarStart = NoopOrExt(LS.IndVarStart); + EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt); B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); PreheaderJump->eraseFromParent(); LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); B.SetInsertPoint(LS.LatchBr); - Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, LS.IndVarBase, - ExitSubloopAt); + Value *IndVarBase = NoopOrExt(LS.IndVarBase); + Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt); Value *CondForBranch = LS.LatchBrExitIdx == 1 ? TakeBackedgeLoopCond @@ -1285,7 +1298,8 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // IterationsLeft - are there any more iterations left, given the original // upper bound on the induction variable? If not, we branch to the "real" // exit. - Value *IterationsLeft = B.CreateICmp(Pred, LS.IndVarBase, LS.LoopExitAt); + Value *LoopExitAt = NoopOrExt(LS.LoopExitAt); + Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt); B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); BranchInst *BranchToContinuation = @@ -1304,15 +1318,14 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( RRI.PHIValuesAtPseudoExit.push_back(NewPHI); } - RRI.IndVarEnd = PHINode::Create(LS.IndVarBase->getType(), 2, "indvar.end", + RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end", BranchToContinuation); - RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader); - RRI.IndVarEnd->addIncoming(LS.IndVarBase, RRI.ExitSelector); + RRI.IndVarEnd->addIncoming(IndVarStart, Preheader); + RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector); // The latch exit now has a branch from `RRI.ExitSelector' instead of // `LS.Latch'. The PHI nodes need to be updated to reflect that. 
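Stepping back to the widening added earlier in this hunk: with irce-allow-narrow-latch, values taken from a narrow latch condition are promoted to the wider range-check type, and the choice between sign and zero extension follows the signedness of the latch predicate, mirroring NoopOrExtend and the NoopOrExt lambda above. A small plain-C++ model of that rule, with purely illustrative types and names:

#include <cstdint>
#include <iostream>

static int64_t noopOrExtend(int32_t NarrowVal, bool SignedPredicate) {
  return SignedPredicate
             ? static_cast<int64_t>(NarrowVal)                          // sext
             : static_cast<int64_t>(static_cast<uint32_t>(NarrowVal));  // zext
}

int main() {
  int32_t LatchBound = -1;
  std::cout << noopOrExtend(LatchBound, /*SignedPredicate=*/true) << '\n';  // -1
  std::cout << noopOrExtend(LatchBound, /*SignedPredicate=*/false) << '\n'; // 4294967295
  return 0;
}

Using the wrong extension would change the value of a negative bound, which is why the predicate's signedness drives the choice.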
- for (PHINode &PN : LS.LatchExit->phis()) - replacePHIBlock(&PN, LS.Latch, RRI.ExitSelector); + LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector); return RRI; } @@ -1322,9 +1335,8 @@ void LoopConstrainer::rewriteIncomingValuesForPHIs( const LoopConstrainer::RewrittenRangeInfo &RRI) const { unsigned PHIIndex = 0; for (PHINode &PN : LS.Header->phis()) - for (unsigned i = 0, e = PN.getNumIncomingValues(); i < e; ++i) - if (PN.getIncomingBlock(i) == ContinuationBlock) - PN.setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]); + PN.setIncomingValueForBlock(ContinuationBlock, + RRI.PHIValuesAtPseudoExit[PHIIndex++]); LS.IndVarStart = RRI.IndVarEnd; } @@ -1335,9 +1347,7 @@ BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); BranchInst::Create(LS.Header, Preheader); - for (PHINode &PN : LS.Header->phis()) - for (unsigned i = 0, e = PN.getNumIncomingValues(); i < e; ++i) - replacePHIBlock(&PN, OldPreheader, Preheader); + LS.Header->replacePhiUsesWith(OldPreheader, Preheader); return Preheader; } @@ -1393,7 +1403,7 @@ bool LoopConstrainer::run() { SubRanges SR = MaybeSR.getValue(); bool Increasing = MainLoopStructure.IndVarIncreasing; IntegerType *IVTy = - cast<IntegerType>(MainLoopStructure.IndVarBase->getType()); + cast<IntegerType>(Range.getBegin()->getType()); SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); Instruction *InsertPt = OriginalPreheader->getTerminator(); @@ -1534,7 +1544,7 @@ bool LoopConstrainer::run() { // This function canonicalizes the loop into Loop-Simplify and LCSSA forms. auto CanonicalizeLoop = [&] (Loop *L, bool IsOriginalLoop) { formLCSSARecursively(*L, DT, &LI, &SE); - simplifyLoop(L, &DT, &LI, &SE, nullptr, true); + simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true); // Pre/post loops are slow paths, we do not need to perform any loop // optimizations on them. if (!IsOriginalLoop) @@ -1556,6 +1566,12 @@ Optional<InductiveRangeCheck::Range> InductiveRangeCheck::computeSafeIterationSpace( ScalarEvolution &SE, const SCEVAddRecExpr *IndVar, bool IsLatchSigned) const { + // We can deal when types of latch check and range checks don't match in case + // if latch check is more narrow. 
+ auto *IVType = cast<IntegerType>(IndVar->getType()); + auto *RCType = cast<IntegerType>(getBegin()->getType()); + if (IVType->getBitWidth() > RCType->getBitWidth()) + return None; // IndVar is of the form "A + B * I" (where "I" is the canonical induction // variable, that may or may not exist as a real llvm::Value in the loop) and // this inductive range check is a range check on the "C + D * I" ("C" is @@ -1579,8 +1595,9 @@ InductiveRangeCheck::computeSafeIterationSpace( if (!IndVar->isAffine()) return None; - const SCEV *A = IndVar->getStart(); - const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE)); + const SCEV *A = NoopOrExtend(IndVar->getStart(), RCType, SE, IsLatchSigned); + const SCEVConstant *B = dyn_cast<SCEVConstant>( + NoopOrExtend(IndVar->getStepRecurrence(SE), RCType, SE, IsLatchSigned)); if (!B) return None; assert(!B->isZero() && "Recurrence with zero step?"); @@ -1591,7 +1608,7 @@ InductiveRangeCheck::computeSafeIterationSpace( return None; assert(!D->getValue()->isZero() && "Recurrence with zero step?"); - unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); + unsigned BitWidth = RCType->getBitWidth(); const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); // Subtract Y from X so that it does not go through border of the IV diff --git a/lib/Transforms/Scalar/InferAddressSpaces.cpp b/lib/Transforms/Scalar/InferAddressSpaces.cpp index fbbc09eb487f..5f0e2001c73d 100644 --- a/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -1,9 +1,8 @@ //===- InferAddressSpace.cpp - --------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -149,7 +148,9 @@ class InferAddressSpaces : public FunctionPass { public: static char ID; - InferAddressSpaces() : FunctionPass(ID) {} + InferAddressSpaces() : + FunctionPass(ID), FlatAddrSpace(UninitializedAddressSpace) {} + InferAddressSpaces(unsigned AS) : FunctionPass(ID), FlatAddrSpace(AS) {} void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); @@ -218,13 +219,17 @@ static bool isAddressExpression(const Value &V) { if (!isa<Operator>(V)) return false; - switch (cast<Operator>(V).getOpcode()) { + const Operator &Op = cast<Operator>(V); + switch (Op.getOpcode()) { case Instruction::PHI: + assert(Op.getType()->isPointerTy()); + return true; case Instruction::BitCast: case Instruction::AddrSpaceCast: case Instruction::GetElementPtr: - case Instruction::Select: return true; + case Instruction::Select: + return Op.getType()->isPointerTy(); default: return false; } @@ -548,10 +553,17 @@ static Value *cloneConstantExprWithNewAddressSpace( if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand)) { IsNew = true; NewOperands.push_back(cast<Constant>(NewOperand)); - } else { - // Otherwise, reuses the old operand. 
- NewOperands.push_back(Operand); + continue; } + if (auto CExpr = dyn_cast<ConstantExpr>(Operand)) + if (Value *NewOperand = cloneConstantExprWithNewAddressSpace( + CExpr, NewAddrSpace, ValueWithNewAddrSpace)) { + IsNew = true; + NewOperands.push_back(cast<Constant>(NewOperand)); + continue; + } + // Otherwise, reuses the old operand. + NewOperands.push_back(Operand); } // If !IsNew, we will replace the Value with itself. However, replaced values @@ -621,9 +633,12 @@ bool InferAddressSpaces::runOnFunction(Function &F) { const TargetTransformInfo &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - FlatAddrSpace = TTI.getFlatAddressSpace(); - if (FlatAddrSpace == UninitializedAddressSpace) - return false; + + if (FlatAddrSpace == UninitializedAddressSpace) { + FlatAddrSpace = TTI.getFlatAddressSpace(); + if (FlatAddrSpace == UninitializedAddressSpace) + return false; + } // Collects all flat address expressions in postorder. std::vector<WeakTrackingVH> Postorder = collectFlatAddressExpressions(F); @@ -991,8 +1006,12 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces( } // Otherwise, replaces the use with flat(NewV). - if (Instruction *I = dyn_cast<Instruction>(V)) { - BasicBlock::iterator InsertPos = std::next(I->getIterator()); + if (Instruction *Inst = dyn_cast<Instruction>(V)) { + // Don't create a copy of the original addrspacecast. + if (U == V && isa<AddrSpaceCastInst>(V)) + continue; + + BasicBlock::iterator InsertPos = std::next(Inst->getIterator()); while (isa<PHINode>(InsertPos)) ++InsertPos; U.set(new AddrSpaceCastInst(NewV, V->getType(), "", &*InsertPos)); @@ -1015,6 +1034,6 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces( return true; } -FunctionPass *llvm::createInferAddressSpacesPass() { - return new InferAddressSpaces(); +FunctionPass *llvm::createInferAddressSpacesPass(unsigned AddressSpace) { + return new InferAddressSpaces(AddressSpace); } diff --git a/lib/Transforms/Scalar/InstSimplifyPass.cpp b/lib/Transforms/Scalar/InstSimplifyPass.cpp index 05cd48d83267..6616364ab203 100644 --- a/lib/Transforms/Scalar/InstSimplifyPass.cpp +++ b/lib/Transforms/Scalar/InstSimplifyPass.cpp @@ -1,9 +1,8 @@ //===- InstSimplifyPass.cpp -----------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index 48de56a02834..b86bf2fefbe5 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -1,9 +1,8 @@ //===- JumpThreading.cpp - Thread control through conditional blocks ------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -24,6 +23,7 @@ #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -38,7 +38,6 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -103,6 +102,12 @@ static cl::opt<bool> PrintLVIAfterJumpThreading( cl::desc("Print the LazyValueInfo cache after JumpThreading"), cl::init(false), cl::Hidden); +static cl::opt<bool> ThreadAcrossLoopHeaders( + "jump-threading-across-loop-headers", + cl::desc("Allow JumpThreading to thread across loop headers, for testing"), + cl::init(false), cl::Hidden); + + namespace { /// This pass performs 'jump threading', which looks at blocks that have @@ -369,7 +374,8 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, if (!DT.isReachableFromEntry(&BB)) Unreachable.insert(&BB); - FindLoopHeaders(F); + if (!ThreadAcrossLoopHeaders) + FindLoopHeaders(F); bool EverChanged = false; bool Changed; @@ -1056,7 +1062,7 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { Condition = IB->getAddress()->stripPointerCasts(); Preference = WantBlockAddress; } else { - return false; // Must be an invoke. + return false; // Must be an invoke or callbr. } // Run constant folding to see if we can reduce the condition to a simple @@ -1092,7 +1098,7 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { << "' folding undef terminator: " << *BBTerm << '\n'); BranchInst::Create(BBTerm->getSuccessor(BestSucc), BBTerm); BBTerm->eraseFromParent(); - DTU->applyUpdates(Updates); + DTU->applyUpdatesPermissive(Updates); return true; } @@ -1143,7 +1149,9 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { unsigned ToKeep = Ret == LazyValueInfo::True ? 0 : 1; BasicBlock *ToRemoveSucc = CondBr->getSuccessor(ToRemove); ToRemoveSucc->removePredecessor(BB, true); - BranchInst::Create(CondBr->getSuccessor(ToKeep), CondBr); + BranchInst *UncondBr = + BranchInst::Create(CondBr->getSuccessor(ToKeep), CondBr); + UncondBr->setDebugLoc(CondBr->getDebugLoc()); CondBr->eraseFromParent(); if (CondCmp->use_empty()) CondCmp->eraseFromParent(); @@ -1160,7 +1168,8 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { ConstantInt::getFalse(CondCmp->getType()); ReplaceFoldableUses(CondCmp, CI); } - DTU->deleteEdgeRelaxed(BB, ToRemoveSucc); + DTU->applyUpdatesPermissive( + {{DominatorTree::Delete, BB, ToRemoveSucc}}); return true; } @@ -1172,7 +1181,8 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { } if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) - TryToUnfoldSelect(SI, BB); + if (TryToUnfoldSelect(SI, BB)) + return true; // Check for some cases that are worth simplifying. Right now we want to look // for loads that are used by a switch or by the condition for the branch. If @@ -1245,9 +1255,10 @@ bool JumpThreadingPass::ProcessImpliedCondition(BasicBlock *BB) { BasicBlock *KeepSucc = BI->getSuccessor(*Implication ? 0 : 1); BasicBlock *RemoveSucc = BI->getSuccessor(*Implication ? 
1 : 0); RemoveSucc->removePredecessor(BB); - BranchInst::Create(KeepSucc, BI); + BranchInst *UncondBI = BranchInst::Create(KeepSucc, BI); + UncondBI->setDebugLoc(BI->getDebugLoc()); BI->eraseFromParent(); - DTU->deleteEdgeRelaxed(BB, RemoveSucc); + DTU->applyUpdatesPermissive({{DominatorTree::Delete, BB, RemoveSucc}}); return true; } CurrentBB = CurrentPred; @@ -1429,7 +1440,9 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) { // Add all the unavailable predecessors to the PredsToSplit list. for (BasicBlock *P : predecessors(LoadBB)) { // If the predecessor is an indirect goto, we can't split the edge. - if (isa<IndirectBrInst>(P->getTerminator())) + // Same for CallBr. + if (isa<IndirectBrInst>(P->getTerminator()) || + isa<CallBrInst>(P->getTerminator())) return false; if (!AvailablePredSet.count(P)) @@ -1446,11 +1459,11 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) { if (UnavailablePred) { assert(UnavailablePred->getTerminator()->getNumSuccessors() == 1 && "Can't handle critical edge here!"); - LoadInst *NewVal = - new LoadInst(LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred), - LoadI->getName() + ".pr", false, LoadI->getAlignment(), - LoadI->getOrdering(), LoadI->getSyncScopeID(), - UnavailablePred->getTerminator()); + LoadInst *NewVal = new LoadInst( + LoadI->getType(), LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred), + LoadI->getName() + ".pr", false, LoadI->getAlignment(), + LoadI->getOrdering(), LoadI->getSyncScopeID(), + UnavailablePred->getTerminator()); NewVal->setDebugLoc(LoadI->getDebugLoc()); if (AATags) NewVal->setAAMetadata(AATags); @@ -1474,8 +1487,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) { for (pred_iterator PI = PB; PI != PE; ++PI) { BasicBlock *P = *PI; AvailablePredsTy::iterator I = - std::lower_bound(AvailablePreds.begin(), AvailablePreds.end(), - std::make_pair(P, (Value*)nullptr)); + llvm::lower_bound(AvailablePreds, std::make_pair(P, (Value *)nullptr)); assert(I != AvailablePreds.end() && I->first == P && "Didn't find entry for predecessor!"); @@ -1601,7 +1613,6 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, Constant *OnlyVal = nullptr; Constant *MultipleVal = (Constant *)(intptr_t)~0ULL; - unsigned PredWithKnownDest = 0; for (const auto &PredValue : PredValues) { BasicBlock *Pred = PredValue.second; if (!SeenPreds.insert(Pred).second) @@ -1638,12 +1649,10 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, OnlyVal = MultipleVal; } - // We know where this predecessor is going. - ++PredWithKnownDest; - // If the predecessor ends with an indirect goto, we can't change its - // destination. - if (isa<IndirectBrInst>(Pred->getTerminator())) + // destination. Same for CallBr. + if (isa<IndirectBrInst>(Pred->getTerminator()) || + isa<CallBrInst>(Pred->getTerminator())) continue; PredToDestList.push_back(std::make_pair(Pred, DestBB)); @@ -1657,7 +1666,7 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, // not thread. By doing so, we do not need to duplicate the current block and // also miss potential opportunities in case we dont/cant duplicate. 
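The SimplifyPartiallyRedundantLoad change in this hunk constructs the replacement LoadInst with an explicit result type instead of deriving it from the pointer operand, part of the broader move away from depending on pointee types. Below is a hedged, self-contained IRBuilder sketch of the same idiom; the toy function and its names are mine, and only the explicit-type CreateLoad call is the point.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("demo", Ctx);
  Type *I32 = Type::getInt32Ty(Ctx);
  FunctionType *FTy = FunctionType::get(I32, {I32->getPointerTo()}, false);
  Function *F =
      Function::Create(FTy, Function::ExternalLinkage, "read_i32", &M);
  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
  IRBuilder<> B(BB);
  Value *Ptr = &*F->arg_begin();
  // The result type is passed explicitly; nothing is taken from the
  // pointer's element type.
  Value *V = B.CreateLoad(I32, Ptr, "val");
  B.CreateRet(V);
  M.print(outs(), nullptr);
  return 0;
}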
if (OnlyDest && OnlyDest != MultipleDestSentinel) { - if (PredWithKnownDest == (size_t)pred_size(BB)) { + if (BB->hasNPredecessors(PredToDestList.size())) { bool SeenFirstBranchToOnlyDest = false; std::vector <DominatorTree::UpdateType> Updates; Updates.reserve(BB->getTerminator()->getNumSuccessors() - 1); @@ -1674,7 +1683,7 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, Instruction *Term = BB->getTerminator(); BranchInst::Create(OnlyDest, Term); Term->eraseFromParent(); - DTU->applyUpdates(Updates); + DTU->applyUpdatesPermissive(Updates); // If the condition is now dead due to the removal of the old terminator, // erase it. @@ -1976,8 +1985,14 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, } BasicBlock::iterator BI = BB->begin(); - for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI) - ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB); + // Clone the phi nodes of BB into NewBB. The resulting phi nodes are trivial, + // since NewBB only has one predecessor, but SSAUpdater might need to rewrite + // the operand of the cloned phi. + for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI) { + PHINode *NewPN = PHINode::Create(PN->getType(), 1, PN->getName(), NewBB); + NewPN->addIncoming(PN->getIncomingValueForBlock(PredBB), PredBB); + ValueMapping[PN] = NewPN; + } // Clone the non-phi instructions of BB into NewBB, keeping track of the // mapping and using it to remap operands in the cloned instructions. @@ -2016,9 +2031,9 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, } // Enqueue required DT updates. - DTU->applyUpdates({{DominatorTree::Insert, NewBB, SuccBB}, - {DominatorTree::Insert, PredBB, NewBB}, - {DominatorTree::Delete, PredBB, BB}}); + DTU->applyUpdatesPermissive({{DominatorTree::Insert, NewBB, SuccBB}, + {DominatorTree::Insert, PredBB, NewBB}, + {DominatorTree::Delete, PredBB, BB}}); // If there were values defined in BB that are used outside the block, then we // now have to update all uses of the value to use either the original value, @@ -2112,7 +2127,7 @@ BasicBlock *JumpThreadingPass::SplitBlockPreds(BasicBlock *BB, BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); } - DTU->applyUpdates(Updates); + DTU->applyUpdatesPermissive(Updates); return NewBBs[0]; } @@ -2385,7 +2400,7 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred( // Remove the unconditional branch at the end of the PredBB block. OldPredBranch->eraseFromParent(); - DTU->applyUpdates(Updates); + DTU->applyUpdatesPermissive(Updates); ++NumDupes; return true; @@ -2421,8 +2436,8 @@ void JumpThreadingPass::UnfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB, // The select is now dead. SI->eraseFromParent(); - DTU->applyUpdates({{DominatorTree::Insert, NewBB, BB}, - {DominatorTree::Insert, Pred, NewBB}}); + DTU->applyUpdatesPermissive({{DominatorTree::Insert, NewBB, BB}, + {DominatorTree::Insert, Pred, NewBB}}); // Update any other PHI nodes in BB. 
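Several calls in these hunks switch from applyUpdates to applyUpdatesPermissive, which accepts update batches that may contain duplicates or entries that no longer correspond to a real CFG change and simply drops them. The following is a rough conceptual model in plain C++; it is not the DomTreeUpdater implementation, just the filtering idea.

#include <iostream>
#include <set>
#include <utility>
#include <vector>

using Edge = std::pair<int, int>; // (from block, to block)
struct Update { bool Insert; Edge E; };

static void applyPermissive(std::set<Edge> &CFG,
                            const std::vector<Update> &Updates) {
  std::set<std::pair<bool, Edge>> Seen;
  for (const Update &U : Updates) {
    if (!Seen.insert({U.Insert, U.E}).second)
      continue;                   // drop duplicate updates
    bool Present = CFG.count(U.E) != 0;
    if (U.Insert == Present)
      continue;                   // drop updates that change nothing
    if (U.Insert)
      CFG.insert(U.E);
    else
      CFG.erase(U.E);
  }
}

int main() {
  std::set<Edge> CFG = {{0, 1}, {1, 2}};
  applyPermissive(CFG, {{true, {0, 2}}, {true, {0, 2}}, {false, {3, 4}}});
  std::cout << CFG.size() << '\n'; // 3: one insertion applied, the rest dropped
  return 0;
}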
for (BasicBlock::iterator BI = BB->begin(); @@ -2599,7 +2614,7 @@ bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) { Updates.push_back({DominatorTree::Delete, BB, Succ}); Updates.push_back({DominatorTree::Insert, SplitBB, Succ}); } - DTU->applyUpdates(Updates); + DTU->applyUpdatesPermissive(Updates); return true; } return false; diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp index d204654c3915..d9dda4cef2d2 100644 --- a/lib/Transforms/Scalar/LICM.cpp +++ b/lib/Transforms/Scalar/LICM.cpp @@ -1,9 +1,8 @@ //===-- LICM.cpp - Loop Invariant Code Motion Pass ------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -55,6 +54,7 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" @@ -107,17 +107,29 @@ static cl::opt<int> LICMN2Theshold("licm-n2-threshold", cl::Hidden, cl::init(0), cl::desc("How many instruction to cross product using AA")); -// Experimental option to allow imprecision in LICM (use MemorySSA cap) in -// pathological cases, in exchange for faster compile. This is to be removed -// if MemorySSA starts to address the same issue. This flag applies only when -// LICM uses MemorySSA instead on AliasSetTracker. When the flag is disabled -// (default), LICM calls MemorySSAWalker's getClobberingMemoryAccess, which -// gets perfect accuracy. When flag is enabled, LICM will call into MemorySSA's -// getDefiningAccess, which may not be precise, since optimizeUses is capped. -static cl::opt<bool> EnableLicmCap( - "enable-licm-cap", cl::init(false), cl::Hidden, - cl::desc("Enable imprecision in LICM (uses MemorySSA cap) in " - "pathological cases, in exchange for faster compile")); +// Experimental option to allow imprecision in LICM in pathological cases, in +// exchange for faster compile. This is to be removed if MemorySSA starts to +// address the same issue. This flag applies only when LICM uses MemorySSA +// instead on AliasSetTracker. LICM calls MemorySSAWalker's +// getClobberingMemoryAccess, up to the value of the Cap, getting perfect +// accuracy. Afterwards, LICM will call into MemorySSA's getDefiningAccess, +// which may not be precise, since optimizeUses is capped. The result is +// correct, but we may not get as "far up" as possible to get which access is +// clobbering the one queried. +cl::opt<unsigned> llvm::SetLicmMssaOptCap( + "licm-mssa-optimization-cap", cl::init(100), cl::Hidden, + cl::desc("Enable imprecision in LICM in pathological cases, in exchange " + "for faster compile. Caps the MemorySSA clobbering calls.")); + +// Experimentally, memory promotion carries less importance than sinking and +// hoisting. Limit when we do promotion when using MemorySSA, in order to save +// compile time. +cl::opt<unsigned> llvm::SetLicmMssaNoAccForPromotionCap( + "licm-mssa-max-acc-promotion", cl::init(250), cl::Hidden, + cl::desc("[LICM & MemorySSA] When MSSA in LICM is disabled, this has no " + "effect. 
When MSSA in LICM is enabled, then this is the maximum " + "number of accesses allowed to be present in a loop in order to " + "enable memory promotion.")); static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI); static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, @@ -128,8 +140,7 @@ static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE); static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, const Loop *CurLoop, ICFLoopSafetyInfo *SafetyInfo, - MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE, - bool FreeInLoop); + MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE); static bool isSafeToExecuteUnconditionally(Instruction &Inst, const DominatorTree *DT, const Loop *CurLoop, @@ -140,7 +151,8 @@ static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, AliasSetTracker *CurAST, Loop *CurLoop, AliasAnalysis *AA); static bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU, - Loop *CurLoop); + Loop *CurLoop, + SinkAndHoistLICMFlags &Flags); static Instruction *CloneInstructionInExitBlock( Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU); @@ -149,7 +161,8 @@ static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, AliasSetTracker *AST, MemorySSAUpdater *MSSAU); static void moveInstructionBefore(Instruction &I, Instruction &Dest, - ICFLoopSafetyInfo &SafetyInfo); + ICFLoopSafetyInfo &SafetyInfo, + MemorySSAUpdater *MSSAU); namespace { struct LoopInvariantCodeMotion { @@ -160,17 +173,29 @@ struct LoopInvariantCodeMotion { OptimizationRemarkEmitter *ORE, bool DeleteAST); ASTrackerMapTy &getLoopToAliasSetMap() { return LoopToAliasSetMap; } + LoopInvariantCodeMotion(unsigned LicmMssaOptCap, + unsigned LicmMssaNoAccForPromotionCap) + : LicmMssaOptCap(LicmMssaOptCap), + LicmMssaNoAccForPromotionCap(LicmMssaNoAccForPromotionCap) {} private: ASTrackerMapTy LoopToAliasSetMap; + unsigned LicmMssaOptCap; + unsigned LicmMssaNoAccForPromotionCap; std::unique_ptr<AliasSetTracker> collectAliasInfoForLoop(Loop *L, LoopInfo *LI, AliasAnalysis *AA); + std::unique_ptr<AliasSetTracker> + collectAliasInfoForLoopWithMSSA(Loop *L, AliasAnalysis *AA, + MemorySSAUpdater *MSSAU); }; struct LegacyLICMPass : public LoopPass { static char ID; // Pass identification, replacement for typeid - LegacyLICMPass() : LoopPass(ID) { + LegacyLICMPass( + unsigned LicmMssaOptCap = SetLicmMssaOptCap, + unsigned LicmMssaNoAccForPromotionCap = SetLicmMssaNoAccForPromotionCap) + : LoopPass(ID), LICM(LicmMssaOptCap, LicmMssaNoAccForPromotionCap) { initializeLegacyLICMPassPass(*PassRegistry::getPassRegistry()); } @@ -219,8 +244,16 @@ struct LegacyLICMPass : public LoopPass { using llvm::Pass::doFinalization; bool doFinalization() override { - assert(LICM.getLoopToAliasSetMap().empty() && + auto &AliasSetMap = LICM.getLoopToAliasSetMap(); + // All loops in the AliasSetMap should be cleaned up already. The only case + // where we fail to do so is if an outer loop gets deleted before LICM + // visits it. 
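The two options above replace the old all-or-nothing enable-licm-cap flag with numeric budgets. The first, licm-mssa-optimization-cap, bounds how many precise clobber walks LICM will pay for before falling back to the cheaper defining-access answer. A plain-C++ sketch of that budgeting; the query functions are placeholders, not MemorySSA calls.

#include <cstdio>

struct WalkBudget {
  unsigned Used = 0;
  unsigned Cap;
  explicit WalkBudget(unsigned Cap) : Cap(Cap) {}
  bool allowPreciseQuery() { return Used++ < Cap; }
};

// Placeholders for the two query flavours.
static const char *preciseClobberWalk() { return "walker (precise)"; }
static const char *cheapDefiningAccess() {
  return "defining access (may be imprecise)";
}

int main() {
  WalkBudget Budget(/*Cap=*/2);
  for (int Access = 0; Access < 4; ++Access)
    std::printf("access %d -> %s\n", Access,
                Budget.allowPreciseQuery() ? preciseClobberWalk()
                                           : cheapDefiningAccess());
  return 0;
}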
+ assert(all_of(AliasSetMap, + [](LoopInvariantCodeMotion::ASTrackerMapTy::value_type &KV) { + return !KV.first->getParentLoop(); + }) && "Didn't free loop alias sets"); + AliasSetMap.clear(); return false; } @@ -252,7 +285,7 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, report_fatal_error("LICM: OptimizationRemarkEmitterAnalysis not " "cached at a higher level"); - LoopInvariantCodeMotion LICM; + LoopInvariantCodeMotion LICM(LicmMssaOptCap, LicmMssaNoAccForPromotionCap); if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.TTI, &AR.SE, AR.MSSA, ORE, true)) return PreservedAnalyses::all(); @@ -261,6 +294,8 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, PA.preserve<DominatorTreeAnalysis>(); PA.preserve<LoopAnalysis>(); + if (EnableMSSALoopDependency) + PA.preserve<MemorySSAAnalysis>(); return PA; } @@ -276,6 +311,10 @@ INITIALIZE_PASS_END(LegacyLICMPass, "licm", "Loop Invariant Code Motion", false, false) Pass *llvm::createLICMPass() { return new LegacyLICMPass(); } +Pass *llvm::createLICMPass(unsigned LicmMssaOptCap, + unsigned LicmMssaNoAccForPromotionCap) { + return new LegacyLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap); +} /// Hoist expressions out of the specified loop. Note, alias info for inner /// loop is not preserved so it is not a good idea to run LICM multiple @@ -293,12 +332,31 @@ bool LoopInvariantCodeMotion::runOnLoop( std::unique_ptr<AliasSetTracker> CurAST; std::unique_ptr<MemorySSAUpdater> MSSAU; + bool NoOfMemAccTooLarge = false; + unsigned LicmMssaOptCounter = 0; + if (!MSSA) { LLVM_DEBUG(dbgs() << "LICM: Using Alias Set Tracker.\n"); CurAST = collectAliasInfoForLoop(L, LI, AA); } else { - LLVM_DEBUG(dbgs() << "LICM: Using MemorySSA. Promotion disabled.\n"); + LLVM_DEBUG(dbgs() << "LICM: Using MemorySSA.\n"); MSSAU = make_unique<MemorySSAUpdater>(MSSA); + + unsigned AccessCapCount = 0; + for (auto *BB : L->getBlocks()) { + if (auto *Accesses = MSSA->getBlockAccesses(BB)) { + for (const auto &MA : *Accesses) { + (void)MA; + AccessCapCount++; + if (AccessCapCount > LicmMssaNoAccForPromotionCap) { + NoOfMemAccTooLarge = true; + break; + } + } + } + if (NoOfMemAccTooLarge) + break; + } } // Get the preheader block to move instructions into... @@ -317,13 +375,16 @@ bool LoopInvariantCodeMotion::runOnLoop( // that we are guaranteed to see definitions before we see uses. This allows // us to sink instructions in one pass, without iteration. After sinking // instructions, we perform another pass to hoist them out of the loop. - // + SinkAndHoistLICMFlags Flags = {NoOfMemAccTooLarge, LicmMssaOptCounter, + LicmMssaOptCap, LicmMssaNoAccForPromotionCap, + /*IsSink=*/true}; if (L->hasDedicatedExits()) Changed |= sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, TTI, L, - CurAST.get(), MSSAU.get(), &SafetyInfo, ORE); + CurAST.get(), MSSAU.get(), &SafetyInfo, Flags, ORE); + Flags.IsSink = false; if (Preheader) Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, L, - CurAST.get(), MSSAU.get(), &SafetyInfo, ORE); + CurAST.get(), MSSAU.get(), &SafetyInfo, Flags, ORE); // Now that all loop invariants have been removed from the loop, promote any // memory references to scalars that we can. @@ -332,7 +393,8 @@ bool LoopInvariantCodeMotion::runOnLoop( // make sure we catch that. An additional load may be generated in the // preheader for SSA updater, so also avoid sinking when no preheader // is available. 
- if (!DisablePromotion && Preheader && L->hasDedicatedExits()) { + if (!DisablePromotion && Preheader && L->hasDedicatedExits() && + !NoOfMemAccTooLarge) { // Figure out the loop exits and their insertion points SmallVector<BasicBlock *, 8> ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); @@ -344,38 +406,45 @@ bool LoopInvariantCodeMotion::runOnLoop( if (!HasCatchSwitch) { SmallVector<Instruction *, 8> InsertPts; + SmallVector<MemoryAccess *, 8> MSSAInsertPts; InsertPts.reserve(ExitBlocks.size()); - for (BasicBlock *ExitBlock : ExitBlocks) + if (MSSAU) + MSSAInsertPts.reserve(ExitBlocks.size()); + for (BasicBlock *ExitBlock : ExitBlocks) { InsertPts.push_back(&*ExitBlock->getFirstInsertionPt()); + if (MSSAU) + MSSAInsertPts.push_back(nullptr); + } PredIteratorCache PIC; bool Promoted = false; - if (CurAST.get()) { - // Loop over all of the alias sets in the tracker object. - for (AliasSet &AS : *CurAST) { - // We can promote this alias set if it has a store, if it is a "Must" - // alias set, if the pointer is loop invariant, and if we are not - // eliminating any volatile loads or stores. - if (AS.isForwardingAliasSet() || !AS.isMod() || !AS.isMustAlias() || - !L->isLoopInvariant(AS.begin()->getValue())) - continue; - - assert( - !AS.empty() && - "Must alias set should have at least one pointer element in it!"); - - SmallSetVector<Value *, 8> PointerMustAliases; - for (const auto &ASI : AS) - PointerMustAliases.insert(ASI.getValue()); - - Promoted |= promoteLoopAccessesToScalars( - PointerMustAliases, ExitBlocks, InsertPts, PIC, LI, DT, TLI, L, - CurAST.get(), &SafetyInfo, ORE); - } + // Build an AST using MSSA. + if (!CurAST.get()) + CurAST = collectAliasInfoForLoopWithMSSA(L, AA, MSSAU.get()); + + // Loop over all of the alias sets in the tracker object. + for (AliasSet &AS : *CurAST) { + // We can promote this alias set if it has a store, if it is a "Must" + // alias set, if the pointer is loop invariant, and if we are not + // eliminating any volatile loads or stores. + if (AS.isForwardingAliasSet() || !AS.isMod() || !AS.isMustAlias() || + !L->isLoopInvariant(AS.begin()->getValue())) + continue; + + assert( + !AS.empty() && + "Must alias set should have at least one pointer element in it!"); + + SmallSetVector<Value *, 8> PointerMustAliases; + for (const auto &ASI : AS) + PointerMustAliases.insert(ASI.getValue()); + + Promoted |= promoteLoopAccessesToScalars( + PointerMustAliases, ExitBlocks, InsertPts, MSSAInsertPts, PIC, LI, + DT, TLI, L, CurAST.get(), MSSAU.get(), &SafetyInfo, ORE); } - // FIXME: Promotion initially disabled when using MemorySSA. // Once we have promoted values across the loop body we have to // recursively reform LCSSA as any nested loop may now have values defined @@ -399,7 +468,7 @@ bool LoopInvariantCodeMotion::runOnLoop( // If this loop is nested inside of another one, save the alias information // for when we process the outer loop. - if (CurAST.get() && L->getParentLoop() && !DeleteAST) + if (!MSSAU.get() && CurAST.get() && L->getParentLoop() && !DeleteAST) LoopToAliasSetMap[L] = std::move(CurAST); if (MSSAU.get() && VerifyMemorySSA) @@ -420,6 +489,7 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, TargetTransformInfo *TTI, Loop *CurLoop, AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, ICFLoopSafetyInfo *SafetyInfo, + SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE) { // Verify inputs. 
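The second budget, licm-mssa-max-acc-promotion, is what the counting loop added to runOnLoop above feeds: if the loop body holds more memory accesses than the cap, NoOfMemAccTooLarge is set and promotion is skipped entirely. The same decision in a few lines of plain C++, with illustrative names only:

#include <iostream>
#include <vector>

static bool promotionAllowed(const std::vector<unsigned> &AccessesPerBlock,
                             unsigned Cap) {
  unsigned Count = 0;
  for (unsigned N : AccessesPerBlock) {
    Count += N;
    if (Count > Cap)
      return false; // NoOfMemAccTooLarge in the patch above
  }
  return true;
}

int main() {
  std::cout << promotionAllowed({10, 20, 30}, 250) << '\n'; // 1
  std::cout << promotionAllowed({200, 200}, 250) << '\n';   // 0
  return 0;
}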
@@ -463,9 +533,10 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // bool FreeInLoop = false; if (isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) && - canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, ORE) && + canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, &Flags, + ORE) && !I.mayHaveSideEffects()) { - if (sink(I, LI, DT, CurLoop, SafetyInfo, MSSAU, ORE, FreeInLoop)) { + if (sink(I, LI, DT, CurLoop, SafetyInfo, MSSAU, ORE)) { if (!FreeInLoop) { ++II; eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); @@ -718,6 +789,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop, AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, ICFLoopSafetyInfo *SafetyInfo, + SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE) { // Verify inputs. assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && @@ -770,7 +842,8 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // and we have accurately duplicated the control flow from the loop header // to that block. if (CurLoop->hasLoopInvariantOperands(&I) && - canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, ORE) && + canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, &Flags, + ORE) && isSafeToExecuteUnconditionally( I, DT, CurLoop, SafetyInfo, ORE, CurLoop->getLoopPreheader()->getTerminator())) { @@ -808,13 +881,18 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, continue; } - using namespace PatternMatch; - if (((I.use_empty() && - match(&I, m_Intrinsic<Intrinsic::invariant_start>())) || - isGuard(&I)) && + auto IsInvariantStart = [&](Instruction &I) { + using namespace PatternMatch; + return I.use_empty() && + match(&I, m_Intrinsic<Intrinsic::invariant_start>()); + }; + auto MustExecuteWithoutWritesBefore = [&](Instruction &I) { + return SafetyInfo->isGuaranteedToExecute(I, DT, CurLoop) && + SafetyInfo->doesNotWriteMemoryBefore(I, CurLoop); + }; + if ((IsInvariantStart(I) || isGuard(&I)) && CurLoop->hasLoopInvariantOperands(&I) && - SafetyInfo->isGuaranteedToExecute(I, DT, CurLoop) && - SafetyInfo->doesNotWriteMemoryBefore(I, CurLoop)) { + MustExecuteWithoutWritesBefore(I)) { hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, MSSAU, ORE); HoistedInstructions.push_back(&I); @@ -867,7 +945,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, LLVM_DEBUG(dbgs() << "LICM rehoisting to " << HoistPoint->getParent()->getName() << ": " << *I << "\n"); - moveInstructionBefore(*I, *HoistPoint, *SafetyInfo); + moveInstructionBefore(*I, *HoistPoint, *SafetyInfo, MSSAU); HoistPoint = I; Changed = true; } @@ -897,8 +975,7 @@ static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT, Loop *CurLoop) { Value *Addr = LI->getOperand(0); const DataLayout &DL = LI->getModule()->getDataLayout(); - const uint32_t LocSizeInBits = DL.getTypeSizeInBits( - cast<PointerType>(Addr->getType())->getElementType()); + const uint32_t LocSizeInBits = DL.getTypeSizeInBits(LI->getType()); // if the type is i8 addrspace(x)*, we know this is the type of // llvm.invariant.start operand @@ -945,16 +1022,15 @@ static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT, namespace { /// Return true if-and-only-if we know how to (mechanically) both hoist and /// sink a given instruction out of a loop. Does not address legality -/// concerns such as aliasing or speculation safety. 
+/// concerns such as aliasing or speculation safety. bool isHoistableAndSinkableInst(Instruction &I) { // Only these instructions are hoistable/sinkable. - return (isa<LoadInst>(I) || isa<StoreInst>(I) || - isa<CallInst>(I) || isa<FenceInst>(I) || - isa<BinaryOperator>(I) || isa<CastInst>(I) || - isa<SelectInst>(I) || isa<GetElementPtrInst>(I) || - isa<CmpInst>(I) || isa<InsertElementInst>(I) || - isa<ExtractElementInst>(I) || isa<ShuffleVectorInst>(I) || - isa<ExtractValueInst>(I) || isa<InsertValueInst>(I)); + return (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<CallInst>(I) || + isa<FenceInst>(I) || isa<BinaryOperator>(I) || isa<CastInst>(I) || + isa<SelectInst>(I) || isa<GetElementPtrInst>(I) || isa<CmpInst>(I) || + isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) || + isa<ShuffleVectorInst>(I) || isa<ExtractValueInst>(I) || + isa<InsertValueInst>(I)); } /// Return true if all of the alias sets within this AST are known not to /// contain a Mod, or if MSSA knows thare are no MemoryDefs in the loop. @@ -997,12 +1073,15 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, Loop *CurLoop, AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, bool TargetExecutesOncePerLoop, + SinkAndHoistLICMFlags *Flags, OptimizationRemarkEmitter *ORE) { // If we don't understand the instruction, bail early. if (!isHoistableAndSinkableInst(I)) return false; MemorySSA *MSSA = MSSAU ? MSSAU->getMemorySSA() : nullptr; + if (MSSA) + assert(Flags != nullptr && "Flags cannot be null."); // Loads have extra constraints we have to verify before we can hoist them. if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { @@ -1029,7 +1108,7 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, CurLoop, AA); else Invalidated = pointerInvalidatedByLoopWithMSSA( - MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(LI)), CurLoop); + MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(LI)), CurLoop, *Flags); // Check loop-invariant address because this may also be a sinkable load // whose address is not necessarily loop-invariant. if (ORE && Invalidated && CurLoop->isLoopInvariant(LI->getPointerOperand())) @@ -1074,7 +1153,8 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, CurAST, CurLoop, AA); else Invalidated = pointerInvalidatedByLoopWithMSSA( - MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(CI)), CurLoop); + MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(CI)), CurLoop, + *Flags); if (Invalidated) return false; } @@ -1133,13 +1213,46 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, } else { // MSSAU if (isOnlyMemoryAccess(SI, CurLoop, MSSAU)) return true; - if (!EnableLicmCap) { - auto *Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(SI); - if (MSSA->isLiveOnEntryDef(Source) || - !CurLoop->contains(Source->getBlock())) - return true; - } - return false; + // If there are more accesses than the Promotion cap, give up, we're not + // walking a list that long. + if (Flags->NoOfMemAccTooLarge) + return false; + // Check store only if there's still "quota" to check clobber. + if (Flags->LicmMssaOptCounter >= Flags->LicmMssaOptCap) + return false; + // If there are interfering Uses (i.e. their defining access is in the + // loop), or ordered loads (stored as Defs!), don't move this store. + // Could do better here, but this is conservatively correct. + // TODO: Cache set of Uses on the first walk in runOnLoop, update when + // moving accesses. Can also extend to dominating uses. 
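The comment above describes the new store handling under MemorySSA: before spending one of the capped clobber walks, the code scans the loop's accesses and refuses to move the store if any MemoryUse is defined inside the loop or if an ordered load (recorded as a MemoryDef) is present. A simplified plain-C++ model of that scan follows; it deliberately omits the extra dominance check applied when hoisting rather than sinking.

#include <iostream>
#include <vector>

struct LoopMemAccess {
  bool IsUse;         // MemoryUse (unordered load) vs MemoryDef
  bool DefinedInLoop; // for uses: defining access lies inside the loop
  bool IsOrderedLoad; // for defs: a load with ordering constraints
};

static bool mayMoveStore(const std::vector<LoopMemAccess> &Accesses) {
  for (const LoopMemAccess &A : Accesses) {
    if (A.IsUse && A.DefinedInLoop)
      return false; // a use inside the loop may observe the store
    if (!A.IsUse && A.IsOrderedLoad)
      return false; // conservatively keep ordered loads in place
  }
  return true;      // still subject to the capped clobber query
}

int main() {
  std::cout << mayMoveStore({{true, false, false}, {false, false, false}})
            << '\n';                              // 1
  std::cout << mayMoveStore({{true, true, false}}) << '\n'; // 0
  return 0;
}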
+ auto *SIMD = MSSA->getMemoryAccess(SI); + for (auto *BB : CurLoop->getBlocks()) + if (auto *Accesses = MSSA->getBlockAccesses(BB)) { + for (const auto &MA : *Accesses) + if (const auto *MU = dyn_cast<MemoryUse>(&MA)) { + auto *MD = MU->getDefiningAccess(); + if (!MSSA->isLiveOnEntryDef(MD) && + CurLoop->contains(MD->getBlock())) + return false; + // Disable hoisting past potentially interfering loads. Optimized + // Uses may point to an access outside the loop, as getClobbering + // checks the previous iteration when walking the backedge. + // FIXME: More precise: no Uses that alias SI. + if (!Flags->IsSink && !MSSA->dominates(SIMD, MU)) + return false; + } else if (const auto *MD = dyn_cast<MemoryDef>(&MA)) + if (auto *LI = dyn_cast<LoadInst>(MD->getMemoryInst())) { + (void)LI; // Silence warning. + assert(!LI->isUnordered() && "Expected unordered load"); + return false; + } + } + + auto *Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(SI); + Flags->LicmMssaOptCounter++; + // If there are no clobbering Defs in the loop, store is safe to hoist. + return MSSA->isLiveOnEntryDef(Source) || + !CurLoop->contains(Source->getBlock()); } } @@ -1233,7 +1346,7 @@ static Instruction *CloneInstructionInExitBlock( // Sinking call-sites need to be handled differently from other // instructions. The cloned call-site needs a funclet bundle operand - // appropriate for it's location in the CFG. + // appropriate for its location in the CFG. SmallVector<OperandBundleDef, 1> OpBundles; for (unsigned BundleIdx = 0, BundleEnd = CI->getNumOperandBundles(); BundleIdx != BundleEnd; ++BundleIdx) { @@ -1310,10 +1423,15 @@ static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, } static void moveInstructionBefore(Instruction &I, Instruction &Dest, - ICFLoopSafetyInfo &SafetyInfo) { + ICFLoopSafetyInfo &SafetyInfo, + MemorySSAUpdater *MSSAU) { SafetyInfo.removeInstruction(&I); SafetyInfo.insertInstructionTo(&I, Dest.getParent()); I.moveBefore(&Dest); + if (MSSAU) + if (MemoryUseOrDef *OldMemAcc = cast_or_null<MemoryUseOrDef>( + MSSAU->getMemorySSA()->getMemoryAccess(&I))) + MSSAU->moveToPlace(OldMemAcc, Dest.getParent(), MemorySSA::End); } static Instruction *sinkThroughTriviallyReplaceablePHI( @@ -1426,8 +1544,7 @@ static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, /// static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, const Loop *CurLoop, ICFLoopSafetyInfo *SafetyInfo, - MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE, - bool FreeInLoop) { + MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE) { LLVM_DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n"); ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "InstSunk", &I) @@ -1441,7 +1558,7 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, ++NumSunk; // Iterate over users to be ready for actual sinking. Replace users via - // unrechable blocks with undef and make all user PHIs trivially replcable. + // unreachable blocks with undef and make all user PHIs trivially replaceable. SmallPtrSet<Instruction *, 8> VisitedUsers; for (Value::user_iterator UI = I.user_begin(), UE = I.user_end(); UI != UE;) { auto *User = cast<Instruction>(*UI); @@ -1549,25 +1666,15 @@ static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, if (isa<PHINode>(I)) // Move the new node to the end of the phi list in the destination block. 
- moveInstructionBefore(I, *Dest->getFirstNonPHI(), *SafetyInfo); + moveInstructionBefore(I, *Dest->getFirstNonPHI(), *SafetyInfo, MSSAU); else // Move the new node to the destination block, before its terminator. - moveInstructionBefore(I, *Dest->getTerminator(), *SafetyInfo); - if (MSSAU) { - // If moving, I just moved a load or store, so update MemorySSA. - MemoryUseOrDef *OldMemAcc = cast_or_null<MemoryUseOrDef>( - MSSAU->getMemorySSA()->getMemoryAccess(&I)); - if (OldMemAcc) - MSSAU->moveToPlace(OldMemAcc, Dest, MemorySSA::End); - } + moveInstructionBefore(I, *Dest->getTerminator(), *SafetyInfo, MSSAU); - // Do not retain debug locations when we are moving instructions to different - // basic blocks, because we want to avoid jumpy line tables. Calls, however, - // need to retain their debug locs because they may be inlined. - // FIXME: How do we retain source locations without causing poor debugging - // behavior? - if (!isa<CallInst>(I)) - I.setDebugLoc(DebugLoc()); + // Apply line 0 debug locations when we are moving instructions to different + // basic blocks because we want to avoid jumpy line tables. + if (const DebugLoc &DL = I.getDebugLoc()) + I.setDebugLoc(DebugLoc::get(0, 0, DL.getScope(), DL.getInlinedAt())); if (isa<LoadInst>(I)) ++NumMovedLoads; @@ -1611,8 +1718,10 @@ class LoopPromoter : public LoadAndStorePromoter { const SmallSetVector<Value *, 8> &PointerMustAliases; SmallVectorImpl<BasicBlock *> &LoopExitBlocks; SmallVectorImpl<Instruction *> &LoopInsertPts; + SmallVectorImpl<MemoryAccess *> &MSSAInsertPts; PredIteratorCache &PredCache; AliasSetTracker &AST; + MemorySSAUpdater *MSSAU; LoopInfo &LI; DebugLoc DL; int Alignment; @@ -1639,15 +1748,16 @@ public: LoopPromoter(Value *SP, ArrayRef<const Instruction *> Insts, SSAUpdater &S, const SmallSetVector<Value *, 8> &PMA, SmallVectorImpl<BasicBlock *> &LEB, - SmallVectorImpl<Instruction *> &LIP, PredIteratorCache &PIC, - AliasSetTracker &ast, LoopInfo &li, DebugLoc dl, int alignment, - bool UnorderedAtomic, const AAMDNodes &AATags, - ICFLoopSafetyInfo &SafetyInfo) + SmallVectorImpl<Instruction *> &LIP, + SmallVectorImpl<MemoryAccess *> &MSSAIP, PredIteratorCache &PIC, + AliasSetTracker &ast, MemorySSAUpdater *MSSAU, LoopInfo &li, + DebugLoc dl, int alignment, bool UnorderedAtomic, + const AAMDNodes &AATags, ICFLoopSafetyInfo &SafetyInfo) : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA), - LoopExitBlocks(LEB), LoopInsertPts(LIP), PredCache(PIC), AST(ast), - LI(li), DL(std::move(dl)), Alignment(alignment), - UnorderedAtomic(UnorderedAtomic), AATags(AATags), SafetyInfo(SafetyInfo) - {} + LoopExitBlocks(LEB), LoopInsertPts(LIP), MSSAInsertPts(MSSAIP), + PredCache(PIC), AST(ast), MSSAU(MSSAU), LI(li), DL(std::move(dl)), + Alignment(alignment), UnorderedAtomic(UnorderedAtomic), AATags(AATags), + SafetyInfo(SafetyInfo) {} bool isInstInList(Instruction *I, const SmallVectorImpl<Instruction *> &) const override { @@ -1659,7 +1769,7 @@ public: return PointerMustAliases.count(Ptr); } - void doExtraRewritesBeforeFinalDeletion() const override { + void doExtraRewritesBeforeFinalDeletion() override { // Insert stores after in the loop exit blocks. Each exit block gets a // store of the live-out values that feed them. 
Since we've already told // the SSA updater about the defs in the loop and the preheader @@ -1677,6 +1787,21 @@ public: NewSI->setDebugLoc(DL); if (AATags) NewSI->setAAMetadata(AATags); + + if (MSSAU) { + MemoryAccess *MSSAInsertPoint = MSSAInsertPts[i]; + MemoryAccess *NewMemAcc; + if (!MSSAInsertPoint) { + NewMemAcc = MSSAU->createMemoryAccessInBB( + NewSI, nullptr, NewSI->getParent(), MemorySSA::Beginning); + } else { + NewMemAcc = + MSSAU->createMemoryAccessAfter(NewSI, nullptr, MSSAInsertPoint); + } + MSSAInsertPts[i] = NewMemAcc; + MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true); + // FIXME: true for safety, false may still be correct. + } } } @@ -1687,6 +1812,8 @@ public: void instructionDeleted(Instruction *I) const override { SafetyInfo.removeInstruction(I); AST.deleteValue(I); + if (MSSAU) + MSSAU->removeMemoryAccess(I); } }; @@ -1723,10 +1850,11 @@ bool isKnownNonEscaping(Value *Object, const TargetLibraryInfo *TLI) { bool llvm::promoteLoopAccessesToScalars( const SmallSetVector<Value *, 8> &PointerMustAliases, SmallVectorImpl<BasicBlock *> &ExitBlocks, - SmallVectorImpl<Instruction *> &InsertPts, PredIteratorCache &PIC, + SmallVectorImpl<Instruction *> &InsertPts, + SmallVectorImpl<MemoryAccess *> &MSSAInsertPts, PredIteratorCache &PIC, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - Loop *CurLoop, AliasSetTracker *CurAST, ICFLoopSafetyInfo *SafetyInfo, - OptimizationRemarkEmitter *ORE) { + Loop *CurLoop, AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, + ICFLoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE) { // Verify inputs. assert(LI != nullptr && DT != nullptr && CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr && @@ -1827,9 +1955,21 @@ bool llvm::promoteLoopAccessesToScalars( SawUnorderedAtomic |= Load->isAtomic(); SawNotAtomic |= !Load->isAtomic(); - if (!DereferenceableInPH) - DereferenceableInPH = isSafeToExecuteUnconditionally( - *Load, DT, CurLoop, SafetyInfo, ORE, Preheader->getTerminator()); + unsigned InstAlignment = Load->getAlignment(); + if (!InstAlignment) + InstAlignment = + MDL.getABITypeAlignment(Load->getType()); + + // Note that proving a load safe to speculate requires proving + // sufficient alignment at the target location. Proving it guaranteed + // to execute does as well. Thus we can increase our guaranteed + // alignment as well. + if (!DereferenceableInPH || (InstAlignment > Alignment)) + if (isSafeToExecuteUnconditionally(*Load, DT, CurLoop, SafetyInfo, + ORE, Preheader->getTerminator())) { + DereferenceableInPH = true; + Alignment = std::max(Alignment, InstAlignment); + } } else if (const StoreInst *Store = dyn_cast<StoreInst>(UI)) { // Stores *of* the pointer are not interesting, only stores *to* the // pointer. @@ -1875,8 +2015,8 @@ bool llvm::promoteLoopAccessesToScalars( // deref info through it. if (!DereferenceableInPH) { DereferenceableInPH = isDereferenceableAndAlignedPointer( - Store->getPointerOperand(), Store->getAlignment(), MDL, - Preheader->getTerminator(), DT); + Store->getPointerOperand(), Store->getValueOperand()->getType(), + Store->getAlignment(), MDL, Preheader->getTerminator(), DT); } } else return false; // Not a load or store. @@ -1900,6 +2040,14 @@ bool llvm::promoteLoopAccessesToScalars( if (SawUnorderedAtomic && SawNotAtomic) return false; + // If we're inserting an atomic load in the preheader, we must be able to + // lower it. We're only guaranteed to be able to lower naturally aligned + // atomics. 
+ auto *SomePtrElemType = SomePtr->getType()->getPointerElementType(); + if (SawUnorderedAtomic && + Alignment < MDL.getTypeStoreSize(SomePtrElemType)) + return false; + // If we couldn't prove we can hoist the load, bail. if (!DereferenceableInPH) return false; @@ -1943,13 +2091,14 @@ bool llvm::promoteLoopAccessesToScalars( SmallVector<PHINode *, 16> NewPHIs; SSAUpdater SSA(&NewPHIs); LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks, - InsertPts, PIC, *CurAST, *LI, DL, Alignment, - SawUnorderedAtomic, AATags, *SafetyInfo); + InsertPts, MSSAInsertPts, PIC, *CurAST, MSSAU, *LI, DL, + Alignment, SawUnorderedAtomic, AATags, *SafetyInfo); // Set up the preheader to have a definition of the value. It is the live-out // value from the preheader that uses in the loop will use. LoadInst *PreheaderLoad = new LoadInst( - SomePtr, SomePtr->getName() + ".promoted", Preheader->getTerminator()); + SomePtr->getType()->getPointerElementType(), SomePtr, + SomePtr->getName() + ".promoted", Preheader->getTerminator()); if (SawUnorderedAtomic) PreheaderLoad->setOrdering(AtomicOrdering::Unordered); PreheaderLoad->setAlignment(Alignment); @@ -1958,13 +2107,23 @@ bool llvm::promoteLoopAccessesToScalars( PreheaderLoad->setAAMetadata(AATags); SSA.AddAvailableValue(Preheader, PreheaderLoad); + MemoryAccess *PreheaderLoadMemoryAccess; + if (MSSAU) { + PreheaderLoadMemoryAccess = MSSAU->createMemoryAccessInBB( + PreheaderLoad, nullptr, PreheaderLoad->getParent(), MemorySSA::End); + MemoryUse *NewMemUse = cast<MemoryUse>(PreheaderLoadMemoryAccess); + MSSAU->insertUse(NewMemUse); + } + // Rewrite all the loads in the loop and remember all the definitions from // stores in the loop. Promoter.run(LoopUses); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); // If the SSAUpdater didn't use the load in the preheader, just zap it now. if (PreheaderLoad->use_empty()) - eraseInstruction(*PreheaderLoad, *SafetyInfo, CurAST, nullptr); + eraseInstruction(*PreheaderLoad, *SafetyInfo, CurAST, MSSAU); return true; } @@ -2017,6 +2176,15 @@ LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, return CurAST; } +std::unique_ptr<AliasSetTracker> +LoopInvariantCodeMotion::collectAliasInfoForLoopWithMSSA( + Loop *L, AliasAnalysis *AA, MemorySSAUpdater *MSSAU) { + auto *MSSA = MSSAU->getMemorySSA(); + auto CurAST = make_unique<AliasSetTracker>(*AA, MSSA, L); + CurAST->addAllInstructionsInLoopUsingMSSA(); + return CurAST; +} + /// Simple analysis hook. Clone alias set info. /// void LegacyLICMPass::cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, @@ -2095,15 +2263,49 @@ static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, } static bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU, - Loop *CurLoop) { - MemoryAccess *Source; - // See declaration of EnableLicmCap for usage details. - if (EnableLicmCap) - Source = MU->getDefiningAccess(); - else - Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(MU); - return !MSSA->isLiveOnEntryDef(Source) && - CurLoop->contains(Source->getBlock()); + Loop *CurLoop, + SinkAndHoistLICMFlags &Flags) { + // For hoisting, use the walker to determine safety + if (!Flags.IsSink) { + MemoryAccess *Source; + // See declaration of SetLicmMssaOptCap for usage details. 
+ if (Flags.LicmMssaOptCounter >= Flags.LicmMssaOptCap) + Source = MU->getDefiningAccess(); + else { + Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(MU); + Flags.LicmMssaOptCounter++; + } + return !MSSA->isLiveOnEntryDef(Source) && + CurLoop->contains(Source->getBlock()); + } + + // For sinking, we'd need to check all Defs below this use. The getClobbering + // call will look on the backedge of the loop, but will check aliasing with + // the instructions on the previous iteration. + // For example: + // for (i ... ) + // load a[i] ( Use (LoE) + // store a[i] ( 1 = Def (2), with 2 = Phi for the loop. + // i++; + // The load sees no clobbering inside the loop, as the backedge alias check + // does phi translation, and will check aliasing against store a[i-1]. + // However sinking the load outside the loop, below the store is incorrect. + + // For now, only sink if there are no Defs in the loop, and the existing ones + // precede the use and are in the same block. + // FIXME: Increase precision: Safe to sink if Use post dominates the Def; + // needs PostDominatorTreeAnalysis. + // FIXME: More precise: no Defs that alias this Use. + if (Flags.NoOfMemAccTooLarge) + return true; + for (auto *BB : CurLoop->getBlocks()) + if (auto *Accesses = MSSA->getBlockDefs(BB)) + for (const auto &MA : *Accesses) + if (const auto *MD = dyn_cast<MemoryDef>(&MA)) + if (MU->getBlock() != MD->getBlock() || + !MSSA->locallyDominates(MD, MU)) + return true; + return false; } /// Little predicate that returns true if the specified basic block is in diff --git a/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp b/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp index a64c99117d64..1c3ff1a61b7e 100644 --- a/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp +++ b/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp @@ -1,9 +1,8 @@ //===- LoopAccessAnalysisPrinter.cpp - Loop Access Analysis Printer --------==// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/lib/Transforms/Scalar/LoopDataPrefetch.cpp index 3b41b5d96c86..1fcf1315a177 100644 --- a/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -1,9 +1,8 @@ //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -313,7 +312,8 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { IRBuilder<> Builder(MemI); Module *M = BB->getParent()->getParent(); Type *I32 = Type::getInt32Ty(BB->getContext()); - Value *PrefetchFunc = Intrinsic::getDeclaration(M, Intrinsic::prefetch); + Function *PrefetchFunc = + Intrinsic::getDeclaration(M, Intrinsic::prefetch); Builder.CreateCall( PrefetchFunc, {PrefPtrValue, @@ -333,4 +333,3 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { return MadeChange; } - diff --git a/lib/Transforms/Scalar/LoopDeletion.cpp b/lib/Transforms/Scalar/LoopDeletion.cpp index d412025d7e94..8371367e24e7 100644 --- a/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/lib/Transforms/Scalar/LoopDeletion.cpp @@ -1,9 +1,8 @@ //===- LoopDeletion.cpp - Dead Loop Deletion Pass ---------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/LoopDistribute.cpp b/lib/Transforms/Scalar/LoopDistribute.cpp index d797c9dc9e72..f45e5fd0f50b 100644 --- a/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/lib/Transforms/Scalar/LoopDistribute.cpp @@ -1,9 +1,8 @@ //===- LoopDistribute.cpp - Loop Distribution Pass ------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -767,8 +766,14 @@ public: "cannot isolate unsafe dependencies"); } - // Don't distribute the loop if we need too many SCEV run-time checks. + // Don't distribute the loop if we need too many SCEV run-time checks, or + // any if it's illegal. const SCEVUnionPredicate &Pred = LAI->getPSE().getUnionPredicate(); + if (LAI->hasConvergentOp() && !Pred.isAlwaysTrue()) { + return fail("RuntimeCheckWithConvergent", + "may not insert runtime check with convergent operation"); + } + if (Pred.getComplexity() > (IsForced.getValueOr(false) ? 
PragmaDistributeSCEVCheckThreshold : DistributeSCEVCheckThreshold)) @@ -796,7 +801,14 @@ public: auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition, RtPtrChecking); + if (LAI->hasConvergentOp() && !Checks.empty()) { + return fail("RuntimeCheckWithConvergent", + "may not insert runtime check with convergent operation"); + } + if (!Pred.isAlwaysTrue() || !Checks.empty()) { + assert(!LAI->hasConvergentOp() && "inserting illegal loop versioning"); + MDNode *OrigLoopID = L->getLoopID(); LLVM_DEBUG(dbgs() << "\nPointers:\n"); diff --git a/lib/Transforms/Scalar/LoopFuse.cpp b/lib/Transforms/Scalar/LoopFuse.cpp new file mode 100644 index 000000000000..0bc2bcff2ae1 --- /dev/null +++ b/lib/Transforms/Scalar/LoopFuse.cpp @@ -0,0 +1,1215 @@ +//===- LoopFuse.cpp - Loop Fusion Pass ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the loop fusion pass. +/// The implementation is largely based on the following document: +/// +/// Code Transformations to Augment the Scope of Loop Fusion in a +/// Production Compiler +/// Christopher Mark Barton +/// MSc Thesis +/// https://webdocs.cs.ualberta.ca/~amaral/thesis/ChristopherBartonMSc.pdf +/// +/// The general approach taken is to collect sets of control flow equivalent +/// loops and test whether they can be fused. The necessary conditions for +/// fusion are: +/// 1. The loops must be adjacent (there cannot be any statements between +/// the two loops). +/// 2. The loops must be conforming (they must execute the same number of +/// iterations). +/// 3. The loops must be control flow equivalent (if one loop executes, the +/// other is guaranteed to execute). +/// 4. There cannot be any negative distance dependencies between the loops. +/// If all of these conditions are satisfied, it is safe to fuse the loops. +/// +/// This implementation creates FusionCandidates that represent the loop and the +/// necessary information needed by fusion. It then operates on the fusion +/// candidates, first confirming that the candidate is eligible for fusion. The +/// candidates are then collected into control flow equivalent sets, sorted in +/// dominance order. Each set of control flow equivalent candidates is then +/// traversed, attempting to fuse pairs of candidates in the set. If all +/// requirements for fusion are met, the two candidates are fused, creating a +/// new (fused) candidate which is then added back into the set to consider for +/// additional fusion. +/// +/// This implementation currently does not make any modifications to remove +/// conditions for fusion. Code transformations to make loops conform to each of +/// the conditions for fusion are discussed in more detail in the document +/// above. These can be added to the current implementation in the future. 
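
To make the four fusion conditions above concrete, here is a minimal stand-alone C++ sketch (illustrative only, not part of this patch; the names a, b, c and n are made up) of two loops that satisfy them, together with the shape of the fused result, assuming the three arrays do not overlap:

// Before fusion: the two loops are adjacent, run the same number of
// iterations, and both always execute when the function is entered.
void beforeFusion(float *a, const float *b, float *c, int n) {
  for (int i = 0; i < n; ++i)
    a[i] = b[i] + 1.0f;
  for (int i = 0; i < n; ++i)
    c[i] = a[i] * 2.0f;   // reads the a[i] written at the same iteration index
}

// After fusion: one loop body with both statements. Legal here because the
// only cross-loop dependence has distance zero, i.e. it is not negative.
void afterFusion(float *a, const float *b, float *c, int n) {
  for (int i = 0; i < n; ++i) {
    a[i] = b[i] + 1.0f;
    c[i] = a[i] * 2.0f;
  }
}
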
+//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopFuse.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/DependenceAnalysis.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "loop-fusion" + +STATISTIC(FuseCounter, "Count number of loop fusions performed"); +STATISTIC(NumFusionCandidates, "Number of candidates for loop fusion"); +STATISTIC(InvalidPreheader, "Loop has invalid preheader"); +STATISTIC(InvalidHeader, "Loop has invalid header"); +STATISTIC(InvalidExitingBlock, "Loop has invalid exiting blocks"); +STATISTIC(InvalidExitBlock, "Loop has invalid exit block"); +STATISTIC(InvalidLatch, "Loop has invalid latch"); +STATISTIC(InvalidLoop, "Loop is invalid"); +STATISTIC(AddressTakenBB, "Basic block has address taken"); +STATISTIC(MayThrowException, "Loop may throw an exception"); +STATISTIC(ContainsVolatileAccess, "Loop contains a volatile access"); +STATISTIC(NotSimplifiedForm, "Loop is not in simplified form"); +STATISTIC(InvalidDependencies, "Dependencies prevent fusion"); +STATISTIC(InvalidTripCount, + "Loop does not have invariant backedge taken count"); +STATISTIC(UncomputableTripCount, "SCEV cannot compute trip count of loop"); +STATISTIC(NonEqualTripCount, "Candidate trip counts are not the same"); +STATISTIC(NonAdjacent, "Candidates are not adjacent"); +STATISTIC(NonEmptyPreheader, "Candidate has a non-empty preheader"); + +enum FusionDependenceAnalysisChoice { + FUSION_DEPENDENCE_ANALYSIS_SCEV, + FUSION_DEPENDENCE_ANALYSIS_DA, + FUSION_DEPENDENCE_ANALYSIS_ALL, +}; + +static cl::opt<FusionDependenceAnalysisChoice> FusionDependenceAnalysis( + "loop-fusion-dependence-analysis", + cl::desc("Which dependence analysis should loop fusion use?"), + cl::values(clEnumValN(FUSION_DEPENDENCE_ANALYSIS_SCEV, "scev", + "Use the scalar evolution interface"), + clEnumValN(FUSION_DEPENDENCE_ANALYSIS_DA, "da", + "Use the dependence analysis interface"), + clEnumValN(FUSION_DEPENDENCE_ANALYSIS_ALL, "all", + "Use all available analyses")), + cl::Hidden, cl::init(FUSION_DEPENDENCE_ANALYSIS_ALL), cl::ZeroOrMore); + +#ifndef NDEBUG +static cl::opt<bool> + VerboseFusionDebugging("loop-fusion-verbose-debug", + cl::desc("Enable verbose debugging for Loop Fusion"), + cl::Hidden, cl::init(false), cl::ZeroOrMore); +#endif + +/// This class is used to represent a candidate for loop fusion. When it is +/// constructed, it checks the conditions for loop fusion to ensure that it +/// represents a valid candidate. It caches several parts of a loop that are +/// used throughout loop fusion (e.g., loop preheader, loop header, etc) instead +/// of continually querying the underlying Loop to retrieve these values. It is +/// assumed these will not change throughout loop fusion. +/// +/// The invalidate method should be used to indicate that the FusionCandidate is +/// no longer a valid candidate for fusion. 
Similarly, the isValid() method can +/// be used to ensure that the FusionCandidate is still valid for fusion. +struct FusionCandidate { + /// Cache of parts of the loop used throughout loop fusion. These should not + /// need to change throughout the analysis and transformation. + /// These parts are cached to avoid repeatedly looking up in the Loop class. + + /// Preheader of the loop this candidate represents + BasicBlock *Preheader; + /// Header of the loop this candidate represents + BasicBlock *Header; + /// Blocks in the loop that exit the loop + BasicBlock *ExitingBlock; + /// The successor block of this loop (where the exiting blocks go to) + BasicBlock *ExitBlock; + /// Latch of the loop + BasicBlock *Latch; + /// The loop that this fusion candidate represents + Loop *L; + /// Vector of instructions in this loop that read from memory + SmallVector<Instruction *, 16> MemReads; + /// Vector of instructions in this loop that write to memory + SmallVector<Instruction *, 16> MemWrites; + /// Are all of the members of this fusion candidate still valid + bool Valid; + + /// Dominator and PostDominator trees are needed for the + /// FusionCandidateCompare function, required by FusionCandidateSet to + /// determine where the FusionCandidate should be inserted into the set. These + /// are used to establish ordering of the FusionCandidates based on dominance. + const DominatorTree *DT; + const PostDominatorTree *PDT; + + FusionCandidate(Loop *L, const DominatorTree *DT, + const PostDominatorTree *PDT) + : Preheader(L->getLoopPreheader()), Header(L->getHeader()), + ExitingBlock(L->getExitingBlock()), ExitBlock(L->getExitBlock()), + Latch(L->getLoopLatch()), L(L), Valid(true), DT(DT), PDT(PDT) { + + // Walk over all blocks in the loop and check for conditions that may + // prevent fusion. For each block, walk over all instructions and collect + // the memory reads and writes If any instructions that prevent fusion are + // found, invalidate this object and return. + for (BasicBlock *BB : L->blocks()) { + if (BB->hasAddressTaken()) { + AddressTakenBB++; + invalidate(); + return; + } + + for (Instruction &I : *BB) { + if (I.mayThrow()) { + MayThrowException++; + invalidate(); + return; + } + if (StoreInst *SI = dyn_cast<StoreInst>(&I)) { + if (SI->isVolatile()) { + ContainsVolatileAccess++; + invalidate(); + return; + } + } + if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { + if (LI->isVolatile()) { + ContainsVolatileAccess++; + invalidate(); + return; + } + } + if (I.mayWriteToMemory()) + MemWrites.push_back(&I); + if (I.mayReadFromMemory()) + MemReads.push_back(&I); + } + } + } + + /// Check if all members of the class are valid. + bool isValid() const { + return Preheader && Header && ExitingBlock && ExitBlock && Latch && L && + !L->isInvalid() && Valid; + } + + /// Verify that all members are in sync with the Loop object. + void verify() const { + assert(isValid() && "Candidate is not valid!!"); + assert(!L->isInvalid() && "Loop is invalid!"); + assert(Preheader == L->getLoopPreheader() && "Preheader is out of sync"); + assert(Header == L->getHeader() && "Header is out of sync"); + assert(ExitingBlock == L->getExitingBlock() && + "Exiting Blocks is out of sync"); + assert(ExitBlock == L->getExitBlock() && "Exit block is out of sync"); + assert(Latch == L->getLoopLatch() && "Latch is out of sync"); + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const { + dbgs() << "\tPreheader: " << (Preheader ? 
Preheader->getName() : "nullptr") + << "\n" + << "\tHeader: " << (Header ? Header->getName() : "nullptr") << "\n" + << "\tExitingBB: " + << (ExitingBlock ? ExitingBlock->getName() : "nullptr") << "\n" + << "\tExitBB: " << (ExitBlock ? ExitBlock->getName() : "nullptr") + << "\n" + << "\tLatch: " << (Latch ? Latch->getName() : "nullptr") << "\n"; + } +#endif + +private: + // This is only used internally for now, to clear the MemWrites and MemReads + // list and setting Valid to false. I can't envision other uses of this right + // now, since once FusionCandidates are put into the FusionCandidateSet they + // are immutable. Thus, any time we need to change/update a FusionCandidate, + // we must create a new one and insert it into the FusionCandidateSet to + // ensure the FusionCandidateSet remains ordered correctly. + void invalidate() { + MemWrites.clear(); + MemReads.clear(); + Valid = false; + } +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, + const FusionCandidate &FC) { + if (FC.isValid()) + OS << FC.Preheader->getName(); + else + OS << "<Invalid>"; + + return OS; +} + +struct FusionCandidateCompare { + /// Comparison functor to sort two Control Flow Equivalent fusion candidates + /// into dominance order. + /// If LHS dominates RHS and RHS post-dominates LHS, return true; + /// IF RHS dominates LHS and LHS post-dominates RHS, return false; + bool operator()(const FusionCandidate &LHS, + const FusionCandidate &RHS) const { + const DominatorTree *DT = LHS.DT; + + // Do not save PDT to local variable as it is only used in asserts and thus + // will trigger an unused variable warning if building without asserts. + assert(DT && LHS.PDT && "Expecting valid dominator tree"); + + // Do this compare first so if LHS == RHS, function returns false. + if (DT->dominates(RHS.Preheader, LHS.Preheader)) { + // RHS dominates LHS + // Verify LHS post-dominates RHS + assert(LHS.PDT->dominates(LHS.Preheader, RHS.Preheader)); + return false; + } + + if (DT->dominates(LHS.Preheader, RHS.Preheader)) { + // Verify RHS Postdominates LHS + assert(LHS.PDT->dominates(RHS.Preheader, LHS.Preheader)); + return true; + } + + // If LHS does not dominate RHS and RHS does not dominate LHS then there is + // no dominance relationship between the two FusionCandidates. Thus, they + // should not be in the same set together. + llvm_unreachable( + "No dominance relationship between these fusion candidates!"); + } +}; + +namespace { +using LoopVector = SmallVector<Loop *, 4>; + +// Set of Control Flow Equivalent (CFE) Fusion Candidates, sorted in dominance +// order. Thus, if FC0 comes *before* FC1 in a FusionCandidateSet, then FC0 +// dominates FC1 and FC1 post-dominates FC0. +// std::set was chosen because we want a sorted data structure with stable +// iterators. A subsequent patch to loop fusion will enable fusing non-ajdacent +// loops by moving intervening code around. When this intervening code contains +// loops, those loops will be moved also. The corresponding FusionCandidates +// will also need to be moved accordingly. As this is done, having stable +// iterators will simplify the logic. Similarly, having an efficient insert that +// keeps the FusionCandidateSet sorted will also simplify the implementation. 
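
As an aside on the container choice discussed above, the property being relied on can be shown with a small self-contained sketch (not from the patch; Node and ByDepth are made-up stand-ins for the candidate type and the dominance comparator): std::set keeps its elements ordered by the comparator, and iterators to existing elements stay valid as other elements are inserted or erased.

#include <cassert>
#include <set>

struct Node { int depth; };   // stand-in for a fusion candidate

struct ByDepth {              // stand-in for the dominance-order comparator
  bool operator()(const Node &L, const Node &R) const {
    return L.depth < R.depth;
  }
};

int main() {
  std::set<Node, ByDepth> S;
  S.insert({2});
  auto It = S.insert({1}).first;  // iterator to the depth-1 element
  S.insert({3});                  // later inserts do not invalidate It
  assert(It->depth == 1);
  assert(S.begin()->depth == 1);  // the set stays sorted by the comparator
}
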
+using FusionCandidateSet = std::set<FusionCandidate, FusionCandidateCompare>; +using FusionCandidateCollection = SmallVector<FusionCandidateSet, 4>; +} // namespace + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, + const FusionCandidateSet &CandSet) { + for (auto IT : CandSet) + OS << IT << "\n"; + + return OS; +} + +#if !defined(NDEBUG) +static void +printFusionCandidates(const FusionCandidateCollection &FusionCandidates) { + dbgs() << "Fusion Candidates: \n"; + for (const auto &CandidateSet : FusionCandidates) { + dbgs() << "*** Fusion Candidate Set ***\n"; + dbgs() << CandidateSet; + dbgs() << "****************************\n"; + } +} +#endif + +/// Collect all loops in function at the same nest level, starting at the +/// outermost level. +/// +/// This data structure collects all loops at the same nest level for a +/// given function (specified by the LoopInfo object). It starts at the +/// outermost level. +struct LoopDepthTree { + using LoopsOnLevelTy = SmallVector<LoopVector, 4>; + using iterator = LoopsOnLevelTy::iterator; + using const_iterator = LoopsOnLevelTy::const_iterator; + + LoopDepthTree(LoopInfo &LI) : Depth(1) { + if (!LI.empty()) + LoopsOnLevel.emplace_back(LoopVector(LI.rbegin(), LI.rend())); + } + + /// Test whether a given loop has been removed from the function, and thus is + /// no longer valid. + bool isRemovedLoop(const Loop *L) const { return RemovedLoops.count(L); } + + /// Record that a given loop has been removed from the function and is no + /// longer valid. + void removeLoop(const Loop *L) { RemovedLoops.insert(L); } + + /// Descend the tree to the next (inner) nesting level + void descend() { + LoopsOnLevelTy LoopsOnNextLevel; + + for (const LoopVector &LV : *this) + for (Loop *L : LV) + if (!isRemovedLoop(L) && L->begin() != L->end()) + LoopsOnNextLevel.emplace_back(LoopVector(L->begin(), L->end())); + + LoopsOnLevel = LoopsOnNextLevel; + RemovedLoops.clear(); + Depth++; + } + + bool empty() const { return size() == 0; } + size_t size() const { return LoopsOnLevel.size() - RemovedLoops.size(); } + unsigned getDepth() const { return Depth; } + + iterator begin() { return LoopsOnLevel.begin(); } + iterator end() { return LoopsOnLevel.end(); } + const_iterator begin() const { return LoopsOnLevel.begin(); } + const_iterator end() const { return LoopsOnLevel.end(); } + +private: + /// Set of loops that have been removed from the function and are no longer + /// valid. + SmallPtrSet<const Loop *, 8> RemovedLoops; + + /// Depth of the current level, starting at 1 (outermost loops). + unsigned Depth; + + /// Vector of loops at the current depth level that have the same parent loop + LoopsOnLevelTy LoopsOnLevel; +}; + +#ifndef NDEBUG +static void printLoopVector(const LoopVector &LV) { + dbgs() << "****************************\n"; + for (auto L : LV) + printLoop(*L, dbgs()); + dbgs() << "****************************\n"; +} +#endif + +static void reportLoopFusion(const FusionCandidate &FC0, + const FusionCandidate &FC1, + OptimizationRemarkEmitter &ORE) { + using namespace ore; + ORE.emit( + OptimizationRemark(DEBUG_TYPE, "LoopFusion", FC0.Preheader->getParent()) + << "Fused " << NV("Cand1", StringRef(FC0.Preheader->getName())) + << " with " << NV("Cand2", StringRef(FC1.Preheader->getName()))); +} + +struct LoopFuser { +private: + // Sets of control flow equivalent fusion candidates for a given nest level. 
+ FusionCandidateCollection FusionCandidates; + + LoopDepthTree LDT; + DomTreeUpdater DTU; + + LoopInfo &LI; + DominatorTree &DT; + DependenceInfo &DI; + ScalarEvolution &SE; + PostDominatorTree &PDT; + OptimizationRemarkEmitter &ORE; + +public: + LoopFuser(LoopInfo &LI, DominatorTree &DT, DependenceInfo &DI, + ScalarEvolution &SE, PostDominatorTree &PDT, + OptimizationRemarkEmitter &ORE, const DataLayout &DL) + : LDT(LI), DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy), LI(LI), + DT(DT), DI(DI), SE(SE), PDT(PDT), ORE(ORE) {} + + /// This is the main entry point for loop fusion. It will traverse the + /// specified function and collect candidate loops to fuse, starting at the + /// outermost nesting level and working inwards. + bool fuseLoops(Function &F) { +#ifndef NDEBUG + if (VerboseFusionDebugging) { + LI.print(dbgs()); + } +#endif + + LLVM_DEBUG(dbgs() << "Performing Loop Fusion on function " << F.getName() + << "\n"); + bool Changed = false; + + while (!LDT.empty()) { + LLVM_DEBUG(dbgs() << "Got " << LDT.size() << " loop sets for depth " + << LDT.getDepth() << "\n";); + + for (const LoopVector &LV : LDT) { + assert(LV.size() > 0 && "Empty loop set was build!"); + + // Skip singleton loop sets as they do not offer fusion opportunities on + // this level. + if (LV.size() == 1) + continue; +#ifndef NDEBUG + if (VerboseFusionDebugging) { + LLVM_DEBUG({ + dbgs() << " Visit loop set (#" << LV.size() << "):\n"; + printLoopVector(LV); + }); + } +#endif + + collectFusionCandidates(LV); + Changed |= fuseCandidates(); + } + + // Finished analyzing candidates at this level. + // Descend to the next level and clear all of the candidates currently + // collected. Note that it will not be possible to fuse any of the + // existing candidates with new candidates because the new candidates will + // be at a different nest level and thus not be control flow equivalent + // with all of the candidates collected so far. + LLVM_DEBUG(dbgs() << "Descend one level!\n"); + LDT.descend(); + FusionCandidates.clear(); + } + + if (Changed) + LLVM_DEBUG(dbgs() << "Function after Loop Fusion: \n"; F.dump();); + +#ifndef NDEBUG + assert(DT.verify()); + assert(PDT.verify()); + LI.verify(DT); + SE.verify(); +#endif + + LLVM_DEBUG(dbgs() << "Loop Fusion complete\n"); + return Changed; + } + +private: + /// Determine if two fusion candidates are control flow equivalent. + /// + /// Two fusion candidates are control flow equivalent if when one executes, + /// the other is guaranteed to execute. This is determined using dominators + /// and post-dominators: if A dominates B and B post-dominates A then A and B + /// are control-flow equivalent. + bool isControlFlowEquivalent(const FusionCandidate &FC0, + const FusionCandidate &FC1) const { + assert(FC0.Preheader && FC1.Preheader && "Expecting valid preheaders"); + + if (DT.dominates(FC0.Preheader, FC1.Preheader)) + return PDT.dominates(FC1.Preheader, FC0.Preheader); + + if (DT.dominates(FC1.Preheader, FC0.Preheader)) + return PDT.dominates(FC0.Preheader, FC1.Preheader); + + return false; + } + + /// Determine if a fusion candidate (representing a loop) is eligible for + /// fusion. Note that this only checks whether a single loop can be fused - it + /// does not check whether it is *legal* to fuse two loops together. 
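
Control flow equivalence, as computed above from dominance and post-dominance, can be pictured with a small source-level example (hypothetical code, not taken from the patch):

// L0 and L2 are control flow equivalent: whenever one executes, so does the
// other (L0 dominates L2 and L2 post-dominates L0). L1 is not control flow
// equivalent with L0, because when p is false L0 runs but L1 does not.
// Note that L0 and L2 still cannot be fused as-is: the if-statement between
// them means they are not adjacent.
void cfeExample(bool p, int n, int *a, int *b, int *c) {
  for (int i = 0; i < n; ++i)    // L0
    a[i] = i;
  if (p)
    for (int i = 0; i < n; ++i)  // L1
      b[i] = i;
  for (int i = 0; i < n; ++i)    // L2
    c[i] = i;
}
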
+ bool eligibleForFusion(const FusionCandidate &FC) const { + if (!FC.isValid()) { + LLVM_DEBUG(dbgs() << "FC " << FC << " has invalid CFG requirements!\n"); + if (!FC.Preheader) + InvalidPreheader++; + if (!FC.Header) + InvalidHeader++; + if (!FC.ExitingBlock) + InvalidExitingBlock++; + if (!FC.ExitBlock) + InvalidExitBlock++; + if (!FC.Latch) + InvalidLatch++; + if (FC.L->isInvalid()) + InvalidLoop++; + + return false; + } + + // Require ScalarEvolution to be able to determine a trip count. + if (!SE.hasLoopInvariantBackedgeTakenCount(FC.L)) { + LLVM_DEBUG(dbgs() << "Loop " << FC.L->getName() + << " trip count not computable!\n"); + InvalidTripCount++; + return false; + } + + if (!FC.L->isLoopSimplifyForm()) { + LLVM_DEBUG(dbgs() << "Loop " << FC.L->getName() + << " is not in simplified form!\n"); + NotSimplifiedForm++; + return false; + } + + return true; + } + + /// Iterate over all loops in the given loop set and identify the loops that + /// are eligible for fusion. Place all eligible fusion candidates into Control + /// Flow Equivalent sets, sorted by dominance. + void collectFusionCandidates(const LoopVector &LV) { + for (Loop *L : LV) { + FusionCandidate CurrCand(L, &DT, &PDT); + if (!eligibleForFusion(CurrCand)) + continue; + + // Go through each list in FusionCandidates and determine if L is control + // flow equivalent with the first loop in that list. If it is, append LV. + // If not, go to the next list. + // If no suitable list is found, start another list and add it to + // FusionCandidates. + bool FoundSet = false; + + for (auto &CurrCandSet : FusionCandidates) { + if (isControlFlowEquivalent(*CurrCandSet.begin(), CurrCand)) { + CurrCandSet.insert(CurrCand); + FoundSet = true; +#ifndef NDEBUG + if (VerboseFusionDebugging) + LLVM_DEBUG(dbgs() << "Adding " << CurrCand + << " to existing candidate set\n"); +#endif + break; + } + } + if (!FoundSet) { + // No set was found. Create a new set and add to FusionCandidates +#ifndef NDEBUG + if (VerboseFusionDebugging) + LLVM_DEBUG(dbgs() << "Adding " << CurrCand << " to new set\n"); +#endif + FusionCandidateSet NewCandSet; + NewCandSet.insert(CurrCand); + FusionCandidates.push_back(NewCandSet); + } + NumFusionCandidates++; + } + } + + /// Determine if it is beneficial to fuse two loops. + /// + /// For now, this method simply returns true because we want to fuse as much + /// as possible (primarily to test the pass). This method will evolve, over + /// time, to add heuristics for profitability of fusion. + bool isBeneficialFusion(const FusionCandidate &FC0, + const FusionCandidate &FC1) { + return true; + } + + /// Determine if two fusion candidates have the same trip count (i.e., they + /// execute the same number of iterations). + /// + /// Note that for now this method simply returns a boolean value because there + /// are no mechanisms in loop fusion to handle different trip counts. In the + /// future, this behaviour can be extended to adjust one of the loops to make + /// the trip counts equal (e.g., loop peeling). When this is added, this + /// interface may need to change to return more information than just a + /// boolean value. 
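
The trip-count requirement can likewise be illustrated at the source level (hypothetical code, not from the patch; assume a and b are disjoint): the first pair of loops below executes the same number of iterations and passes this check, while the second pair does not and would need something like the loop peeling mentioned above before it could be considered.

void tripCounts(int n, int *a, int *b) {
  // Identical trip counts: passes the identical-trip-count check.
  for (int i = 0; i < n; ++i)
    a[i] = i;
  for (int i = 0; i < n; ++i)
    b[i] = i;

  // Different trip counts (n versus n - 1): rejected.
  for (int i = 0; i < n; ++i)
    a[i] += 1;
  for (int i = 0; i + 1 < n; ++i)
    b[i] += 1;
}
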
+ bool identicalTripCounts(const FusionCandidate &FC0, + const FusionCandidate &FC1) const { + const SCEV *TripCount0 = SE.getBackedgeTakenCount(FC0.L); + if (isa<SCEVCouldNotCompute>(TripCount0)) { + UncomputableTripCount++; + LLVM_DEBUG(dbgs() << "Trip count of first loop could not be computed!"); + return false; + } + + const SCEV *TripCount1 = SE.getBackedgeTakenCount(FC1.L); + if (isa<SCEVCouldNotCompute>(TripCount1)) { + UncomputableTripCount++; + LLVM_DEBUG(dbgs() << "Trip count of second loop could not be computed!"); + return false; + } + LLVM_DEBUG(dbgs() << "\tTrip counts: " << *TripCount0 << " & " + << *TripCount1 << " are " + << (TripCount0 == TripCount1 ? "identical" : "different") + << "\n"); + + return (TripCount0 == TripCount1); + } + + /// Walk each set of control flow equivalent fusion candidates and attempt to + /// fuse them. This does a single linear traversal of all candidates in the + /// set. The conditions for legal fusion are checked at this point. If a pair + /// of fusion candidates passes all legality checks, they are fused together + /// and a new fusion candidate is created and added to the FusionCandidateSet. + /// The original fusion candidates are then removed, as they are no longer + /// valid. + bool fuseCandidates() { + bool Fused = false; + LLVM_DEBUG(printFusionCandidates(FusionCandidates)); + for (auto &CandidateSet : FusionCandidates) { + if (CandidateSet.size() < 2) + continue; + + LLVM_DEBUG(dbgs() << "Attempting fusion on Candidate Set:\n" + << CandidateSet << "\n"); + + for (auto FC0 = CandidateSet.begin(); FC0 != CandidateSet.end(); ++FC0) { + assert(!LDT.isRemovedLoop(FC0->L) && + "Should not have removed loops in CandidateSet!"); + auto FC1 = FC0; + for (++FC1; FC1 != CandidateSet.end(); ++FC1) { + assert(!LDT.isRemovedLoop(FC1->L) && + "Should not have removed loops in CandidateSet!"); + + LLVM_DEBUG(dbgs() << "Attempting to fuse candidate \n"; FC0->dump(); + dbgs() << " with\n"; FC1->dump(); dbgs() << "\n"); + + FC0->verify(); + FC1->verify(); + + if (!identicalTripCounts(*FC0, *FC1)) { + LLVM_DEBUG(dbgs() << "Fusion candidates do not have identical trip " + "counts. Not fusing.\n"); + NonEqualTripCount++; + continue; + } + + if (!isAdjacent(*FC0, *FC1)) { + LLVM_DEBUG(dbgs() + << "Fusion candidates are not adjacent. Not fusing.\n"); + NonAdjacent++; + continue; + } + + // For now we skip fusing if the second candidate has any instructions + // in the preheader. This is done because we currently do not have the + // safety checks to determine if it is save to move the preheader of + // the second candidate past the body of the first candidate. Once + // these checks are added, this condition can be removed. + if (!isEmptyPreheader(*FC1)) { + LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty " + "preheader. Not fusing.\n"); + NonEmptyPreheader++; + continue; + } + + if (!dependencesAllowFusion(*FC0, *FC1)) { + LLVM_DEBUG(dbgs() << "Memory dependencies do not allow fusion!\n"); + continue; + } + + bool BeneficialToFuse = isBeneficialFusion(*FC0, *FC1); + LLVM_DEBUG(dbgs() + << "\tFusion appears to be " + << (BeneficialToFuse ? "" : "un") << "profitable!\n"); + if (!BeneficialToFuse) + continue; + + // All analysis has completed and has determined that fusion is legal + // and profitable. At this point, start transforming the code and + // perform fusion. + + LLVM_DEBUG(dbgs() << "\tFusion is performed: " << *FC0 << " and " + << *FC1 << "\n"); + + // Report fusion to the Optimization Remarks. 
+ // Note this needs to be done *before* performFusion because + // performFusion will change the original loops, making it not + // possible to identify them after fusion is complete. + reportLoopFusion(*FC0, *FC1, ORE); + + FusionCandidate FusedCand(performFusion(*FC0, *FC1), &DT, &PDT); + FusedCand.verify(); + assert(eligibleForFusion(FusedCand) && + "Fused candidate should be eligible for fusion!"); + + // Notify the loop-depth-tree that these loops are not valid objects + // anymore. + LDT.removeLoop(FC1->L); + + CandidateSet.erase(FC0); + CandidateSet.erase(FC1); + + auto InsertPos = CandidateSet.insert(FusedCand); + + assert(InsertPos.second && + "Unable to insert TargetCandidate in CandidateSet!"); + + // Reset FC0 and FC1 the new (fused) candidate. Subsequent iterations + // of the FC1 loop will attempt to fuse the new (fused) loop with the + // remaining candidates in the current candidate set. + FC0 = FC1 = InsertPos.first; + + LLVM_DEBUG(dbgs() << "Candidate Set (after fusion): " << CandidateSet + << "\n"); + + Fused = true; + } + } + } + return Fused; + } + + /// Rewrite all additive recurrences in a SCEV to use a new loop. + class AddRecLoopReplacer : public SCEVRewriteVisitor<AddRecLoopReplacer> { + public: + AddRecLoopReplacer(ScalarEvolution &SE, const Loop &OldL, const Loop &NewL, + bool UseMax = true) + : SCEVRewriteVisitor(SE), Valid(true), UseMax(UseMax), OldL(OldL), + NewL(NewL) {} + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + const Loop *ExprL = Expr->getLoop(); + SmallVector<const SCEV *, 2> Operands; + if (ExprL == &OldL) { + Operands.append(Expr->op_begin(), Expr->op_end()); + return SE.getAddRecExpr(Operands, &NewL, Expr->getNoWrapFlags()); + } + + if (OldL.contains(ExprL)) { + bool Pos = SE.isKnownPositive(Expr->getStepRecurrence(SE)); + if (!UseMax || !Pos || !Expr->isAffine()) { + Valid = false; + return Expr; + } + return visit(Expr->getStart()); + } + + for (const SCEV *Op : Expr->operands()) + Operands.push_back(visit(Op)); + return SE.getAddRecExpr(Operands, ExprL, Expr->getNoWrapFlags()); + } + + bool wasValidSCEV() const { return Valid; } + + private: + bool Valid, UseMax; + const Loop &OldL, &NewL; + }; + + /// Return false if the access functions of \p I0 and \p I1 could cause + /// a negative dependence. + bool accessDiffIsPositive(const Loop &L0, const Loop &L1, Instruction &I0, + Instruction &I1, bool EqualIsInvalid) { + Value *Ptr0 = getLoadStorePointerOperand(&I0); + Value *Ptr1 = getLoadStorePointerOperand(&I1); + if (!Ptr0 || !Ptr1) + return false; + + const SCEV *SCEVPtr0 = SE.getSCEVAtScope(Ptr0, &L0); + const SCEV *SCEVPtr1 = SE.getSCEVAtScope(Ptr1, &L1); +#ifndef NDEBUG + if (VerboseFusionDebugging) + LLVM_DEBUG(dbgs() << " Access function check: " << *SCEVPtr0 << " vs " + << *SCEVPtr1 << "\n"); +#endif + AddRecLoopReplacer Rewriter(SE, L0, L1); + SCEVPtr0 = Rewriter.visit(SCEVPtr0); +#ifndef NDEBUG + if (VerboseFusionDebugging) + LLVM_DEBUG(dbgs() << " Access function after rewrite: " << *SCEVPtr0 + << " [Valid: " << Rewriter.wasValidSCEV() << "]\n"); +#endif + if (!Rewriter.wasValidSCEV()) + return false; + + // TODO: isKnownPredicate doesnt work well when one SCEV is loop carried (by + // L0) and the other is not. We could check if it is monotone and test + // the beginning and end value instead. 
+ + BasicBlock *L0Header = L0.getHeader(); + auto HasNonLinearDominanceRelation = [&](const SCEV *S) { + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S); + if (!AddRec) + return false; + return !DT.dominates(L0Header, AddRec->getLoop()->getHeader()) && + !DT.dominates(AddRec->getLoop()->getHeader(), L0Header); + }; + if (SCEVExprContains(SCEVPtr1, HasNonLinearDominanceRelation)) + return false; + + ICmpInst::Predicate Pred = + EqualIsInvalid ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_SGE; + bool IsAlwaysGE = SE.isKnownPredicate(Pred, SCEVPtr0, SCEVPtr1); +#ifndef NDEBUG + if (VerboseFusionDebugging) + LLVM_DEBUG(dbgs() << " Relation: " << *SCEVPtr0 + << (IsAlwaysGE ? " >= " : " may < ") << *SCEVPtr1 + << "\n"); +#endif + return IsAlwaysGE; + } + + /// Return true if the dependences between @p I0 (in @p L0) and @p I1 (in + /// @p L1) allow loop fusion of @p L0 and @p L1. The dependence analyses + /// specified by @p DepChoice are used to determine this. + bool dependencesAllowFusion(const FusionCandidate &FC0, + const FusionCandidate &FC1, Instruction &I0, + Instruction &I1, bool AnyDep, + FusionDependenceAnalysisChoice DepChoice) { +#ifndef NDEBUG + if (VerboseFusionDebugging) { + LLVM_DEBUG(dbgs() << "Check dep: " << I0 << " vs " << I1 << " : " + << DepChoice << "\n"); + } +#endif + switch (DepChoice) { + case FUSION_DEPENDENCE_ANALYSIS_SCEV: + return accessDiffIsPositive(*FC0.L, *FC1.L, I0, I1, AnyDep); + case FUSION_DEPENDENCE_ANALYSIS_DA: { + auto DepResult = DI.depends(&I0, &I1, true); + if (!DepResult) + return true; +#ifndef NDEBUG + if (VerboseFusionDebugging) { + LLVM_DEBUG(dbgs() << "DA res: "; DepResult->dump(dbgs()); + dbgs() << " [#l: " << DepResult->getLevels() << "][Ordered: " + << (DepResult->isOrdered() ? "true" : "false") + << "]\n"); + LLVM_DEBUG(dbgs() << "DepResult Levels: " << DepResult->getLevels() + << "\n"); + } +#endif + + if (DepResult->getNextPredecessor() || DepResult->getNextSuccessor()) + LLVM_DEBUG( + dbgs() << "TODO: Implement pred/succ dependence handling!\n"); + + // TODO: Can we actually use the dependence info analysis here? + return false; + } + + case FUSION_DEPENDENCE_ANALYSIS_ALL: + return dependencesAllowFusion(FC0, FC1, I0, I1, AnyDep, + FUSION_DEPENDENCE_ANALYSIS_SCEV) || + dependencesAllowFusion(FC0, FC1, I0, I1, AnyDep, + FUSION_DEPENDENCE_ANALYSIS_DA); + } + + llvm_unreachable("Unknown fusion dependence analysis choice!"); + } + + /// Perform a dependence check and return if @p FC0 and @p FC1 can be fused. 
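
Two kinds of dependences that these checks reject can be seen in a small stand-alone example (hypothetical code, not part of the patch; assume a has at least n + 1 elements and the arrays are disjoint):

// Memory dependence with negative distance: unfused, every a[i] is written
// before any a[i + 1] is read; in a fused body, iteration i would read
// a[i + 1] before iteration i + 1 has written it, so fusion is rejected.
void negativeDistance(int n, int *a, const int *b, int *c) {
  for (int i = 0; i < n; ++i)
    a[i] = b[i];
  for (int i = 0; i < n; ++i)
    c[i] = a[i + 1];
}

// Scalar dependence: the second loop uses the final value of s computed by
// the first loop, so a fused body would read s before the sum is complete.
// Candidates like this are rejected as well (cf. the reaching-definition
// walk over the second candidate's uses in the code below).
int scalarDependence(int n, const int *a, int *b) {
  int s = 0;
  for (int i = 0; i < n; ++i)
    s += a[i];
  for (int i = 0; i < n; ++i)
    b[i] = s;
  return s;
}
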
+ bool dependencesAllowFusion(const FusionCandidate &FC0, + const FusionCandidate &FC1) { + LLVM_DEBUG(dbgs() << "Check if " << FC0 << " can be fused with " << FC1 + << "\n"); + assert(FC0.L->getLoopDepth() == FC1.L->getLoopDepth()); + assert(DT.dominates(FC0.Preheader, FC1.Preheader)); + + for (Instruction *WriteL0 : FC0.MemWrites) { + for (Instruction *WriteL1 : FC1.MemWrites) + if (!dependencesAllowFusion(FC0, FC1, *WriteL0, *WriteL1, + /* AnyDep */ false, + FusionDependenceAnalysis)) { + InvalidDependencies++; + return false; + } + for (Instruction *ReadL1 : FC1.MemReads) + if (!dependencesAllowFusion(FC0, FC1, *WriteL0, *ReadL1, + /* AnyDep */ false, + FusionDependenceAnalysis)) { + InvalidDependencies++; + return false; + } + } + + for (Instruction *WriteL1 : FC1.MemWrites) { + for (Instruction *WriteL0 : FC0.MemWrites) + if (!dependencesAllowFusion(FC0, FC1, *WriteL0, *WriteL1, + /* AnyDep */ false, + FusionDependenceAnalysis)) { + InvalidDependencies++; + return false; + } + for (Instruction *ReadL0 : FC0.MemReads) + if (!dependencesAllowFusion(FC0, FC1, *ReadL0, *WriteL1, + /* AnyDep */ false, + FusionDependenceAnalysis)) { + InvalidDependencies++; + return false; + } + } + + // Walk through all uses in FC1. For each use, find the reaching def. If the + // def is located in FC0 then it is is not safe to fuse. + for (BasicBlock *BB : FC1.L->blocks()) + for (Instruction &I : *BB) + for (auto &Op : I.operands()) + if (Instruction *Def = dyn_cast<Instruction>(Op)) + if (FC0.L->contains(Def->getParent())) { + InvalidDependencies++; + return false; + } + + return true; + } + + /// Determine if the exit block of \p FC0 is the preheader of \p FC1. In this + /// case, there is no code in between the two fusion candidates, thus making + /// them adjacent. + bool isAdjacent(const FusionCandidate &FC0, + const FusionCandidate &FC1) const { + return FC0.ExitBlock == FC1.Preheader; + } + + bool isEmptyPreheader(const FusionCandidate &FC) const { + return FC.Preheader->size() == 1; + } + + /// Fuse two fusion candidates, creating a new fused loop. + /// + /// This method contains the mechanics of fusing two loops, represented by \p + /// FC0 and \p FC1. It is assumed that \p FC0 dominates \p FC1 and \p FC1 + /// postdominates \p FC0 (making them control flow equivalent). It also + /// assumes that the other conditions for fusion have been met: adjacent, + /// identical trip counts, and no negative distance dependencies exist that + /// would prevent fusion. Thus, there is no checking for these conditions in + /// this method. + /// + /// Fusion is performed by rewiring the CFG to update successor blocks of the + /// components of tho loop. Specifically, the following changes are done: + /// + /// 1. The preheader of \p FC1 is removed as it is no longer necessary + /// (because it is currently only a single statement block). + /// 2. The latch of \p FC0 is modified to jump to the header of \p FC1. + /// 3. The latch of \p FC1 i modified to jump to the header of \p FC0. + /// 4. All blocks from \p FC1 are removed from FC1 and added to FC0. + /// + /// All of these modifications are done with dominator tree updates, thus + /// keeping the dominator (and post dominator) information up-to-date. + /// + /// This can be improved in the future by actually merging blocks during + /// fusion. For example, the preheader of \p FC1 can be merged with the + /// preheader of \p FC0. This would allow loops with more than a single + /// statement in the preheader to be fused. 
Similarly, the latch blocks of the + /// two loops could also be fused into a single block. This will require + /// analysis to prove it is safe to move the contents of the block past + /// existing code, which currently has not been implemented. + Loop *performFusion(const FusionCandidate &FC0, const FusionCandidate &FC1) { + assert(FC0.isValid() && FC1.isValid() && + "Expecting valid fusion candidates"); + + LLVM_DEBUG(dbgs() << "Fusion Candidate 0: \n"; FC0.dump(); + dbgs() << "Fusion Candidate 1: \n"; FC1.dump();); + + assert(FC1.Preheader == FC0.ExitBlock); + assert(FC1.Preheader->size() == 1 && + FC1.Preheader->getSingleSuccessor() == FC1.Header); + + // Remember the phi nodes originally in the header of FC0 in order to rewire + // them later. However, this is only necessary if the new loop carried + // values might not dominate the exiting branch. While we do not generally + // test if this is the case but simply insert intermediate phi nodes, we + // need to make sure these intermediate phi nodes have different + // predecessors. To this end, we filter the special case where the exiting + // block is the latch block of the first loop. Nothing needs to be done + // anyway as all loop carried values dominate the latch and thereby also the + // exiting branch. + SmallVector<PHINode *, 8> OriginalFC0PHIs; + if (FC0.ExitingBlock != FC0.Latch) + for (PHINode &PHI : FC0.Header->phis()) + OriginalFC0PHIs.push_back(&PHI); + + // Replace incoming blocks for header PHIs first. + FC1.Preheader->replaceSuccessorsPhiUsesWith(FC0.Preheader); + FC0.Latch->replaceSuccessorsPhiUsesWith(FC1.Latch); + + // Then modify the control flow and update DT and PDT. + SmallVector<DominatorTree::UpdateType, 8> TreeUpdates; + + // The old exiting block of the first loop (FC0) has to jump to the header + // of the second as we need to execute the code in the second header block + // regardless of the trip count. That is, if the trip count is 0, so the + // back edge is never taken, we still have to execute both loop headers, + // especially (but not only!) if the second is a do-while style loop. + // However, doing so might invalidate the phi nodes of the first loop as + // the new values do only need to dominate their latch and not the exiting + // predicate. To remedy this potential problem we always introduce phi + // nodes in the header of the second loop later that select the loop carried + // value, if the second header was reached through an old latch of the + // first, or undef otherwise. This is sound as exiting the first implies the + // second will exit too, __without__ taking the back-edge. [Their + // trip-counts are equal after all. + // KB: Would this sequence be simpler to just just make FC0.ExitingBlock go + // to FC1.Header? I think this is basically what the three sequences are + // trying to accomplish; however, doing this directly in the CFG may mean + // the DT/PDT becomes invalid + FC0.ExitingBlock->getTerminator()->replaceUsesOfWith(FC1.Preheader, + FC1.Header); + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Delete, FC0.ExitingBlock, FC1.Preheader)); + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Insert, FC0.ExitingBlock, FC1.Header)); + + // The pre-header of L1 is not necessary anymore. 
+ assert(pred_begin(FC1.Preheader) == pred_end(FC1.Preheader)); + FC1.Preheader->getTerminator()->eraseFromParent(); + new UnreachableInst(FC1.Preheader->getContext(), FC1.Preheader); + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Delete, FC1.Preheader, FC1.Header)); + + // Moves the phi nodes from the second to the first loops header block. + while (PHINode *PHI = dyn_cast<PHINode>(&FC1.Header->front())) { + if (SE.isSCEVable(PHI->getType())) + SE.forgetValue(PHI); + if (PHI->hasNUsesOrMore(1)) + PHI->moveBefore(&*FC0.Header->getFirstInsertionPt()); + else + PHI->eraseFromParent(); + } + + // Introduce new phi nodes in the second loop header to ensure + // exiting the first and jumping to the header of the second does not break + // the SSA property of the phis originally in the first loop. See also the + // comment above. + Instruction *L1HeaderIP = &FC1.Header->front(); + for (PHINode *LCPHI : OriginalFC0PHIs) { + int L1LatchBBIdx = LCPHI->getBasicBlockIndex(FC1.Latch); + assert(L1LatchBBIdx >= 0 && + "Expected loop carried value to be rewired at this point!"); + + Value *LCV = LCPHI->getIncomingValue(L1LatchBBIdx); + + PHINode *L1HeaderPHI = PHINode::Create( + LCV->getType(), 2, LCPHI->getName() + ".afterFC0", L1HeaderIP); + L1HeaderPHI->addIncoming(LCV, FC0.Latch); + L1HeaderPHI->addIncoming(UndefValue::get(LCV->getType()), + FC0.ExitingBlock); + + LCPHI->setIncomingValue(L1LatchBBIdx, L1HeaderPHI); + } + + // Replace latch terminator destinations. + FC0.Latch->getTerminator()->replaceUsesOfWith(FC0.Header, FC1.Header); + FC1.Latch->getTerminator()->replaceUsesOfWith(FC1.Header, FC0.Header); + + // If FC0.Latch and FC0.ExitingBlock are the same then we have already + // performed the updates above. + if (FC0.Latch != FC0.ExitingBlock) + TreeUpdates.emplace_back(DominatorTree::UpdateType( + DominatorTree::Insert, FC0.Latch, FC1.Header)); + + TreeUpdates.emplace_back(DominatorTree::UpdateType(DominatorTree::Delete, + FC0.Latch, FC0.Header)); + TreeUpdates.emplace_back(DominatorTree::UpdateType(DominatorTree::Insert, + FC1.Latch, FC0.Header)); + TreeUpdates.emplace_back(DominatorTree::UpdateType(DominatorTree::Delete, + FC1.Latch, FC1.Header)); + + // Update DT/PDT + DTU.applyUpdates(TreeUpdates); + + LI.removeBlock(FC1.Preheader); + DTU.deleteBB(FC1.Preheader); + DTU.flush(); + + // Is there a way to keep SE up-to-date so we don't need to forget the loops + // and rebuild the information in subsequent passes of fusion? + SE.forgetLoop(FC1.L); + SE.forgetLoop(FC0.L); + + // Merge the loops. + SmallVector<BasicBlock *, 8> Blocks(FC1.L->block_begin(), + FC1.L->block_end()); + for (BasicBlock *BB : Blocks) { + FC0.L->addBlockEntry(BB); + FC1.L->removeBlockFromLoop(BB); + if (LI.getLoopFor(BB) != FC1.L) + continue; + LI.changeLoopFor(BB, FC0.L); + } + while (!FC1.L->empty()) { + const auto &ChildLoopIt = FC1.L->begin(); + Loop *ChildLoop = *ChildLoopIt; + FC1.L->removeChildLoop(ChildLoopIt); + FC0.L->addChildLoop(ChildLoop); + } + + // Delete the now empty loop L1. 
+ LI.erase(FC1.L); + +#ifndef NDEBUG + assert(!verifyFunction(*FC0.Header->getParent(), &errs())); + assert(DT.verify(DominatorTree::VerificationLevel::Fast)); + assert(PDT.verify()); + LI.verify(DT); + SE.verify(); +#endif + + FuseCounter++; + + LLVM_DEBUG(dbgs() << "Fusion done:\n"); + + return FC0.L; + } +}; + +struct LoopFuseLegacy : public FunctionPass { + + static char ID; + + LoopFuseLegacy() : FunctionPass(ID) { + initializeLoopFuseLegacyPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequiredID(LoopSimplifyID); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + AU.addRequired<DependenceAnalysisWrapperPass>(); + + AU.addPreserved<ScalarEvolutionWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<PostDominatorTreeWrapperPass>(); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &DI = getAnalysis<DependenceAnalysisWrapperPass>().getDI(); + auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + + const DataLayout &DL = F.getParent()->getDataLayout(); + LoopFuser LF(LI, DT, DI, SE, PDT, ORE, DL); + return LF.fuseLoops(F); + } +}; + +PreservedAnalyses LoopFusePass::run(Function &F, FunctionAnalysisManager &AM) { + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &DI = AM.getResult<DependenceAnalysis>(F); + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + + const DataLayout &DL = F.getParent()->getDataLayout(); + LoopFuser LF(LI, DT, DI, SE, PDT, ORE, DL); + bool Changed = LF.fuseLoops(F); + if (!Changed) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<PostDominatorTreeAnalysis>(); + PA.preserve<ScalarEvolutionAnalysis>(); + PA.preserve<LoopAnalysis>(); + return PA; +} + +char LoopFuseLegacy::ID = 0; + +INITIALIZE_PASS_BEGIN(LoopFuseLegacy, "loop-fusion", "Loop Fusion", false, + false) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_END(LoopFuseLegacy, "loop-fusion", "Loop Fusion", false, false) + +FunctionPass *llvm::createLoopFusePass() { return new LoopFuseLegacy(); } diff --git a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index fbffa1920a84..e561494f19cf 100644 --- a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -1,9 +1,8 @@ //===- LoopIdiomRecognize.cpp - Loop idiom recognition --------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed 
under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -37,6 +36,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -51,12 +51,12 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -87,8 +87,8 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include <algorithm> #include <cassert> @@ -120,6 +120,7 @@ class LoopIdiomRecognize { TargetLibraryInfo *TLI; const TargetTransformInfo *TTI; const DataLayout *DL; + OptimizationRemarkEmitter &ORE; bool ApplyCodeSizeHeuristics; public: @@ -127,8 +128,9 @@ public: LoopInfo *LI, ScalarEvolution *SE, TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, - const DataLayout *DL) - : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL) {} + const DataLayout *DL, + OptimizationRemarkEmitter &ORE) + : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) {} bool runOnLoop(Loop *L); @@ -221,7 +223,12 @@ public: *L->getHeader()->getParent()); const DataLayout *DL = &L->getHeader()->getModule()->getDataLayout(); - LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL); + // For the old PM, we can't use OptimizationRemarkEmitter as an analysis + // pass. Function analyses need to be preserved across loop transformations + // but ORE cannot be preserved (see comment before the pass definition). + OptimizationRemarkEmitter ORE(L->getHeader()->getParent()); + + LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL, ORE); return LIR.runOnLoop(L); } @@ -243,7 +250,19 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, LPMUpdater &) { const auto *DL = &L.getHeader()->getModule()->getDataLayout(); - LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, DL); + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + Function *F = L.getHeader()->getParent(); + + auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); + // FIXME: This should probably be optional rather than required. 
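// A hedged aside (not taken from the patch above): a loop pass can only *borrow*
// function-level results such as ORE that an enclosing function pass manager has
// already cached; it must not compute them itself. Once a reference is in hand,
// remarks are emitted through a closure so the message is only built when remark
// output is actually enabled. The helper name and message below are invented.
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/DiagnosticInfo.h"
using namespace llvm;

static void emitLoopRemark(OptimizationRemarkEmitter &ORE, const Loop &L,
                           const char *PassName) {
  ORE.emit([&]() {
    // The lambda only runs if remarks are enabled for PassName.
    return OptimizationRemark(PassName, "IdiomRecognized", L.getStartLoc(),
                              L.getHeader())
           << "recognized a loop idiom in "
           << ore::NV("Function", L.getHeader()->getParent());
  });
}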
+ if (!ORE) + report_fatal_error( + "LoopIdiomRecognizePass: OptimizationRemarkEmitterAnalysis not cached " + "at a higher level"); + + LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, DL, + *ORE); if (!LIR.runOnLoop(&L)) return PreservedAnalyses::all(); @@ -285,7 +304,7 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) { // Determine if code size heuristics need to be applied. ApplyCodeSizeHeuristics = - L->getHeader()->getParent()->optForSize() && UseLIRCodeSizeHeurs; + L->getHeader()->getParent()->hasOptSize() && UseLIRCodeSizeHeurs; HasMemset = TLI->has(LibFunc_memset); HasMemsetPattern = TLI->has(LibFunc_memset_pattern16); @@ -313,9 +332,10 @@ bool LoopIdiomRecognize::runOnCountableLoop() { SmallVector<BasicBlock *, 8> ExitBlocks; CurLoop->getUniqueExitBlocks(ExitBlocks); - LLVM_DEBUG(dbgs() << "loop-idiom Scanning: F[" + LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << CurLoop->getHeader()->getParent()->getName() - << "] Loop %" << CurLoop->getHeader()->getName() << "\n"); + << "] Countable Loop %" << CurLoop->getHeader()->getName() + << "\n"); bool MadeChange = false; @@ -430,7 +450,7 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) { // turned into a memset of i8 -1, assuming that all the consecutive bytes // are stored. A store of i32 0x01020304 can never be turned into a memset, // but it can be turned into memset_pattern if the target supports it. - Value *SplatValue = isBytewiseValue(StoredVal); + Value *SplatValue = isBytewiseValue(StoredVal, *DL); Constant *PatternValue = nullptr; // Note: memset and memset_pattern on unordered-atomic is yet not supported @@ -607,7 +627,7 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, Constant *FirstPatternValue = nullptr; if (For == ForMemset::Yes) - FirstSplatValue = isBytewiseValue(FirstStoredVal); + FirstSplatValue = isBytewiseValue(FirstStoredVal, *DL); else FirstPatternValue = getMemSetPatternValue(FirstStoredVal, DL); @@ -640,7 +660,7 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, Constant *SecondPatternValue = nullptr; if (For == ForMemset::Yes) - SecondSplatValue = isBytewiseValue(SecondStoredVal); + SecondSplatValue = isBytewiseValue(SecondStoredVal, *DL); else SecondPatternValue = getMemSetPatternValue(SecondStoredVal, DL); @@ -860,7 +880,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( Value *StoredVal, Instruction *TheStore, SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev, const SCEV *BECount, bool NegStride, bool IsLoopMemset) { - Value *SplatValue = isBytewiseValue(StoredVal); + Value *SplatValue = isBytewiseValue(StoredVal, *DL); Constant *PatternValue = nullptr; if (!SplatValue) @@ -931,9 +951,8 @@ bool LoopIdiomRecognize::processLoopStridedStore( Module *M = TheStore->getModule(); StringRef FuncName = "memset_pattern16"; - Value *MSP = - M->getOrInsertFunction(FuncName, Builder.getVoidTy(), - Int8PtrTy, Int8PtrTy, IntPtr); + FunctionCallee MSP = M->getOrInsertFunction(FuncName, Builder.getVoidTy(), + Int8PtrTy, Int8PtrTy, IntPtr); inferLibFuncAttributes(M, FuncName, *TLI); // Otherwise we should form a memset_pattern16. 
PatternValue is known to be @@ -952,6 +971,14 @@ bool LoopIdiomRecognize::processLoopStridedStore( << "\n"); NewCall->setDebugLoc(TheStore->getDebugLoc()); + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "ProcessLoopStridedStore", + NewCall->getDebugLoc(), Preheader) + << "Transformed loop-strided store into a call to " + << ore::NV("NewFunction", NewCall->getCalledFunction()) + << "() function"; + }); + // Okay, the memset has been formed. Zap the original store and anything that // feeds into it. for (auto *I : Stores) @@ -1084,6 +1111,14 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, << " from store ptr=" << *StoreEv << " at: " << *SI << "\n"); + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "ProcessLoopStoreOfLoopLoad", + NewCall->getDebugLoc(), Preheader) + << "Formed a call to " + << ore::NV("NewFunction", NewCall->getCalledFunction()) + << "() function"; + }); + // Okay, the memcpy has been formed. Zap the original store and anything that // feeds into it. deleteDeadInstruction(SI); @@ -1109,6 +1144,11 @@ bool LoopIdiomRecognize::avoidLIRForMultiBlockLoop(bool IsMemset, } bool LoopIdiomRecognize::runOnNoncountableLoop() { + LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" + << CurLoop->getHeader()->getParent()->getName() + << "] Noncountable Loop %" + << CurLoop->getHeader()->getName() << "\n"); + return recognizePopcount() || recognizeAndInsertFFS(); } @@ -1462,9 +1502,15 @@ bool LoopIdiomRecognize::recognizeAndInsertFFS() { const Value *Args[] = {InitX, ZeroCheck ? ConstantInt::getTrue(InitX->getContext()) : ConstantInt::getFalse(InitX->getContext())}; - if (CurLoop->getHeader()->size() != IdiomCanonicalSize && + + // @llvm.dbg doesn't count as they have no semantic effect. + auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug(); + uint32_t HeaderSize = + std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end()); + + if (HeaderSize != IdiomCanonicalSize && TTI->getIntrinsicCost(IntrinID, InitX->getType(), Args) > - TargetTransformInfo::TCC_Basic) + TargetTransformInfo::TCC_Basic) return false; transformLoopToCountable(IntrinID, PH, CntInst, CntPhi, InitX, DefX, @@ -1529,7 +1575,7 @@ static CallInst *createPopcntIntrinsic(IRBuilder<> &IRBuilder, Value *Val, Type *Tys[] = {Val->getType()}; Module *M = IRBuilder.GetInsertBlock()->getParent()->getParent(); - Value *Func = Intrinsic::getDeclaration(M, Intrinsic::ctpop, Tys); + Function *Func = Intrinsic::getDeclaration(M, Intrinsic::ctpop, Tys); CallInst *CI = IRBuilder.CreateCall(Func, Ops); CI->setDebugLoc(DL); @@ -1543,7 +1589,7 @@ static CallInst *createFFSIntrinsic(IRBuilder<> &IRBuilder, Value *Val, Type *Tys[] = {Val->getType()}; Module *M = IRBuilder.GetInsertBlock()->getParent()->getParent(); - Value *Func = Intrinsic::getDeclaration(M, IID, Tys); + Function *Func = Intrinsic::getDeclaration(M, IID, Tys); CallInst *CI = IRBuilder.CreateCall(Func, Ops); CI->setDebugLoc(DL); diff --git a/lib/Transforms/Scalar/LoopInstSimplify.cpp b/lib/Transforms/Scalar/LoopInstSimplify.cpp index 6f7dc2429c09..31191b52895c 100644 --- a/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -1,9 +1,8 @@ //===- LoopInstSimplify.cpp - Loop Instruction Simplification Pass --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -234,6 +233,8 @@ PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, auto PA = getLoopPassPreservedAnalyses(); PA.preserveSet<CFGAnalyses>(); + if (EnableMSSALoopDependency) + PA.preserve<MemorySSAAnalysis>(); return PA; } diff --git a/lib/Transforms/Scalar/LoopInterchange.cpp b/lib/Transforms/Scalar/LoopInterchange.cpp index 766e39b439a0..9a42365adc1b 100644 --- a/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/lib/Transforms/Scalar/LoopInterchange.cpp @@ -1,9 +1,8 @@ //===- LoopInterchange.cpp - Loop interchange pass-------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -1265,9 +1264,7 @@ bool LoopInterchangeTransform::transform() { } void LoopInterchangeTransform::splitInnerLoopLatch(Instruction *Inc) { - BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch(); - BasicBlock *InnerLoopLatchPred = InnerLoopLatch; - InnerLoopLatch = SplitBlock(InnerLoopLatchPred, Inc, DT, LI); + SplitBlock(InnerLoop->getLoopLatch(), Inc, DT, LI); } /// \brief Move all instructions except the terminator from FromBB right before @@ -1280,17 +1277,6 @@ static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) { FromBB->getTerminator()->getIterator()); } -static void updateIncomingBlock(BasicBlock *CurrBlock, BasicBlock *OldPred, - BasicBlock *NewPred) { - for (PHINode &PHI : CurrBlock->phis()) { - unsigned Num = PHI.getNumIncomingValues(); - for (unsigned i = 0; i < Num; ++i) { - if (PHI.getIncomingBlock(i) == OldPred) - PHI.setIncomingBlock(i, NewPred); - } - } -} - /// Update BI to jump to NewBB instead of OldBB. Records updates to /// the dominator tree in DTUpdates, if DT should be preserved. static void updateSuccessor(BranchInst *BI, BasicBlock *OldBB, @@ -1313,8 +1299,41 @@ static void updateSuccessor(BranchInst *BI, BasicBlock *OldBB, } // Move Lcssa PHIs to the right place. -static void moveLCSSAPhis(BasicBlock *InnerExit, BasicBlock *InnerLatch, - BasicBlock *OuterLatch) { +static void moveLCSSAPhis(BasicBlock *InnerExit, BasicBlock *InnerHeader, + BasicBlock *InnerLatch, BasicBlock *OuterHeader, + BasicBlock *OuterLatch, BasicBlock *OuterExit) { + + // Deal with LCSSA PHI nodes in the exit block of the inner loop, that are + // defined either in the header or latch. Those blocks will become header and + // latch of the new outer loop, and the only possible users can PHI nodes + // in the exit block of the loop nest or the outer loop header (reduction + // PHIs, in that case, the incoming value must be defined in the inner loop + // header). We can just substitute the user with the incoming value and remove + // the PHI. + for (PHINode &P : make_early_inc_range(InnerExit->phis())) { + assert(P.getNumIncomingValues() == 1 && + "Only loops with a single exit are supported!"); + + // Incoming values are guaranteed be instructions currently. 
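// A hedged aside (not taken from the patch above): in LCSSA form, every value
// defined in a loop and used outside it is routed through a PHI in the exit
// block; with a single exit that PHI has exactly one incoming value and can be
// folded into its operand. A minimal sketch of that folding step, with invented
// names; make_early_inc_range allows erasing the current PHI while iterating.
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

static void foldSingleEntryLCSSAPhis(BasicBlock *ExitBB) {
  for (PHINode &P : make_early_inc_range(ExitBB->phis())) {
    if (P.getNumIncomingValues() != 1)
      continue;
    // Safe only when the incoming definition dominates all former users.
    P.replaceAllUsesWith(P.getIncomingValue(0));
    P.eraseFromParent();
  }
}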
+ auto IncI = cast<Instruction>(P.getIncomingValueForBlock(InnerLatch)); + // Skip phis with incoming values from the inner loop body, excluding the + // header and latch. + if (IncI->getParent() != InnerLatch && IncI->getParent() != InnerHeader) + continue; + + assert(all_of(P.users(), + [OuterHeader, OuterExit, IncI, InnerHeader](User *U) { + return (cast<PHINode>(U)->getParent() == OuterHeader && + IncI->getParent() == InnerHeader) || + cast<PHINode>(U)->getParent() == OuterExit; + }) && + "Can only replace phis iff the uses are in the loop nest exit or " + "the incoming value is defined in the inner header (it will " + "dominate all loop blocks after interchanging)"); + P.replaceAllUsesWith(IncI); + P.eraseFromParent(); + } + SmallVector<PHINode *, 8> LcssaInnerExit; for (PHINode &P : InnerExit->phis()) LcssaInnerExit.push_back(&P); @@ -1327,35 +1346,43 @@ static void moveLCSSAPhis(BasicBlock *InnerExit, BasicBlock *InnerLatch, // If a PHI node has users outside of InnerExit, it has a use outside the // interchanged loop and we have to preserve it. We move these to // InnerLatch, which will become the new exit block for the innermost - // loop after interchanging. For PHIs only used in InnerExit, we can just - // replace them with the incoming value. - for (PHINode *P : LcssaInnerExit) { - bool hasUsersOutside = false; - for (auto UI = P->use_begin(), E = P->use_end(); UI != E;) { - Use &U = *UI; - ++UI; - auto *Usr = cast<Instruction>(U.getUser()); - if (Usr->getParent() != InnerExit) { - hasUsersOutside = true; - continue; - } - U.set(P->getIncomingValueForBlock(InnerLatch)); - } - if (hasUsersOutside) - P->moveBefore(InnerLatch->getFirstNonPHI()); - else - P->eraseFromParent(); - } + // loop after interchanging. + for (PHINode *P : LcssaInnerExit) + P->moveBefore(InnerLatch->getFirstNonPHI()); // If the inner loop latch contains LCSSA PHIs, those come from a child loop // and we have to move them to the new inner latch. for (PHINode *P : LcssaInnerLatch) P->moveBefore(InnerExit->getFirstNonPHI()); + // Deal with LCSSA PHI nodes in the loop nest exit block. For PHIs that have + // incoming values from the outer latch or header, we have to add a new PHI + // in the inner loop latch, which became the exit block of the outer loop, + // after interchanging. + if (OuterExit) { + for (PHINode &P : OuterExit->phis()) { + if (P.getNumIncomingValues() != 1) + continue; + // Skip Phis with incoming values not defined in the outer loop's header + // and latch. Also skip incoming phis defined in the latch. Those should + // already have been updated. + auto I = dyn_cast<Instruction>(P.getIncomingValue(0)); + if (!I || ((I->getParent() != OuterLatch || isa<PHINode>(I)) && + I->getParent() != OuterHeader)) + continue; + + PHINode *NewPhi = dyn_cast<PHINode>(P.clone()); + NewPhi->setIncomingValue(0, P.getIncomingValue(0)); + NewPhi->setIncomingBlock(0, OuterLatch); + NewPhi->insertBefore(InnerLatch->getFirstNonPHI()); + P.setIncomingValue(0, NewPhi); + } + } + // Now adjust the incoming blocks for the LCSSA PHIs. // For PHIs moved from Inner's exit block, we need to replace Inner's latch // with the new latch. - updateIncomingBlock(InnerLatch, InnerLatch, OuterLatch); + InnerLatch->replacePhiUsesWith(InnerLatch, OuterLatch); } bool LoopInterchangeTransform::adjustLoopBranches() { @@ -1374,9 +1401,11 @@ bool LoopInterchangeTransform::adjustLoopBranches() { // preheaders do not satisfy those conditions. 
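// A hedged aside (not taken from the patch above): interchange wants each loop
// to have a *dedicated* preheader (no PHIs, a single predecessor) so branches
// can be rewired without disturbing unrelated incoming edges. A sketch of the
// check-and-create pattern used below; it assumes the surrounding file's
// includes and the updated InsertPreheaderForLoop signature, which now takes an
// optional MemorySSAUpdater (nullptr when MemorySSA is not being preserved).
static BasicBlock *getOrCreateDedicatedPreheader(Loop *L, DominatorTree *DT,
                                                 LoopInfo *LI) {
  BasicBlock *PH = L->getLoopPreheader();
  if (PH && !isa<PHINode>(PH->begin()) && PH->getUniquePredecessor())
    return PH; // already dedicated
  // PreserveLCSSA, since callers run on loops kept in LCSSA form.
  return InsertPreheaderForLoop(L, DT, LI, /*MSSAU=*/nullptr,
                                /*PreserveLCSSA=*/true);
}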
if (isa<PHINode>(OuterLoopPreHeader->begin()) || !OuterLoopPreHeader->getUniquePredecessor()) - OuterLoopPreHeader = InsertPreheaderForLoop(OuterLoop, DT, LI, true); + OuterLoopPreHeader = + InsertPreheaderForLoop(OuterLoop, DT, LI, nullptr, true); if (InnerLoopPreHeader == OuterLoop->getHeader()) - InnerLoopPreHeader = InsertPreheaderForLoop(InnerLoop, DT, LI, true); + InnerLoopPreHeader = + InsertPreheaderForLoop(InnerLoop, DT, LI, nullptr, true); // Adjust the loop preheader BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); @@ -1422,8 +1451,8 @@ bool LoopInterchangeTransform::adjustLoopBranches() { InnerLoopHeaderSuccessor, DTUpdates); // Adjust reduction PHI's now that the incoming block has changed. - updateIncomingBlock(InnerLoopHeaderSuccessor, InnerLoopHeader, - OuterLoopHeader); + InnerLoopHeaderSuccessor->replacePhiUsesWith(InnerLoopHeader, + OuterLoopHeader); updateSuccessor(InnerLoopHeaderBI, InnerLoopHeaderSuccessor, OuterLoopPreHeader, DTUpdates); @@ -1452,10 +1481,11 @@ bool LoopInterchangeTransform::adjustLoopBranches() { restructureLoops(OuterLoop, InnerLoop, InnerLoopPreHeader, OuterLoopPreHeader); - moveLCSSAPhis(InnerLoopLatchSuccessor, InnerLoopLatch, OuterLoopLatch); + moveLCSSAPhis(InnerLoopLatchSuccessor, InnerLoopHeader, InnerLoopLatch, + OuterLoopHeader, OuterLoopLatch, InnerLoop->getExitBlock()); // For PHIs in the exit block of the outer loop, outer's latch has been // replaced by Inners'. - updateIncomingBlock(OuterLoopLatchSuccessor, OuterLoopLatch, InnerLoopLatch); + OuterLoopLatchSuccessor->replacePhiUsesWith(OuterLoopLatch, InnerLoopLatch); // Now update the reduction PHIs in the inner and outer loop headers. SmallVector<PHINode *, 4> InnerLoopPHIs, OuterLoopPHIs; @@ -1482,10 +1512,10 @@ bool LoopInterchangeTransform::adjustLoopBranches() { } // Update the incoming blocks for moved PHI nodes. - updateIncomingBlock(OuterLoopHeader, InnerLoopPreHeader, OuterLoopPreHeader); - updateIncomingBlock(OuterLoopHeader, InnerLoopLatch, OuterLoopLatch); - updateIncomingBlock(InnerLoopHeader, OuterLoopPreHeader, InnerLoopPreHeader); - updateIncomingBlock(InnerLoopHeader, OuterLoopLatch, InnerLoopLatch); + OuterLoopHeader->replacePhiUsesWith(InnerLoopPreHeader, OuterLoopPreHeader); + OuterLoopHeader->replacePhiUsesWith(InnerLoopLatch, OuterLoopLatch); + InnerLoopHeader->replacePhiUsesWith(OuterLoopPreHeader, InnerLoopPreHeader); + InnerLoopHeader->replacePhiUsesWith(OuterLoopLatch, InnerLoopLatch); return true; } diff --git a/lib/Transforms/Scalar/LoopLoadElimination.cpp b/lib/Transforms/Scalar/LoopLoadElimination.cpp index 19bd9ebcc15b..2b3d5e0ce9b7 100644 --- a/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -1,9 +1,8 @@ //===- LoopLoadElimination.cpp - Loop Load Elimination Pass ---------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -30,10 +29,14 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LazyBlockFrequencyInfo.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -54,6 +57,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Utils/SizeOpts.h" #include <algorithm> #include <cassert> #include <forward_list> @@ -159,8 +163,9 @@ namespace { class LoadEliminationForLoop { public: LoadEliminationForLoop(Loop *L, LoopInfo *LI, const LoopAccessInfo &LAI, - DominatorTree *DT) - : L(L), LI(LI), LAI(LAI), DT(DT), PSE(LAI.getPSE()) {} + DominatorTree *DT, BlockFrequencyInfo *BFI, + ProfileSummaryInfo* PSI) + : L(L), LI(LI), LAI(LAI), DT(DT), BFI(BFI), PSI(PSI), PSE(LAI.getPSE()) {} /// Look through the loop-carried and loop-independent dependences in /// this loop and find store->load dependences. @@ -428,9 +433,9 @@ public: auto *PH = L->getLoopPreheader(); Value *InitialPtr = SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->getType(), PH->getTerminator()); - Value *Initial = - new LoadInst(InitialPtr, "load_initial", /* isVolatile */ false, - Cand.Load->getAlignment(), PH->getTerminator()); + Value *Initial = new LoadInst( + Cand.Load->getType(), InitialPtr, "load_initial", + /* isVolatile */ false, Cand.Load->getAlignment(), PH->getTerminator()); PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded", &L->getHeader()->front()); @@ -529,7 +534,17 @@ public: } if (!Checks.empty() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) { - if (L->getHeader()->getParent()->optForSize()) { + if (LAI.hasConvergentOp()) { + LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with " + "convergent calls\n"); + return false; + } + + auto *HeaderBB = L->getHeader(); + auto *F = HeaderBB->getParent(); + bool OptForSize = F->hasOptSize() || + llvm::shouldOptimizeForSize(HeaderBB, PSI, BFI); + if (OptForSize) { LLVM_DEBUG( dbgs() << "Versioning is needed but not allowed when optimizing " "for size.\n"); @@ -572,6 +587,8 @@ private: LoopInfo *LI; const LoopAccessInfo &LAI; DominatorTree *DT; + BlockFrequencyInfo *BFI; + ProfileSummaryInfo *PSI; PredicatedScalarEvolution PSE; }; @@ -579,6 +596,7 @@ private: static bool eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, + BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, function_ref<const LoopAccessInfo &(Loop &)> GetLAI) { // Build up a worklist of inner-loops to transform to avoid iterator // invalidation. @@ -597,7 +615,7 @@ eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, bool Changed = false; for (Loop *L : Worklist) { // The actual work is performed by LoadEliminationForLoop. 
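// A hedged aside (not taken from the patch above): loop versioning is now gated
// on both the function's size attributes and, when a profile summary exists, on
// profile-driven coldness of the loop header. A minimal sketch of that check;
// the helper name is invented, and BFI may legitimately be null when the module
// carries no profile summary.
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
using namespace llvm;

static bool shouldAvoidVersioning(BasicBlock *HeaderBB,
                                  ProfileSummaryInfo *PSI,
                                  BlockFrequencyInfo *BFI) {
  const Function *F = HeaderBB->getParent();
  // Attribute-driven (optsize/minsize) or PGO-driven size preference.
  return F->hasOptSize() || shouldOptimizeForSize(HeaderBB, PSI, BFI);
}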
- LoadEliminationForLoop LEL(L, &LI, GetLAI(*L), &DT); + LoadEliminationForLoop LEL(L, &LI, GetLAI(*L), &DT, BFI, PSI); Changed |= LEL.processLoop(); } return Changed; @@ -622,10 +640,14 @@ public: auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto &LAA = getAnalysis<LoopAccessLegacyAnalysis>(); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + auto *BFI = (PSI && PSI->hasProfileSummary()) ? + &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() : + nullptr; // Process each loop nest in the function. return eliminateLoadsAcrossLoops( - F, LI, DT, + F, LI, DT, BFI, PSI, [&LAA](Loop &L) -> const LoopAccessInfo & { return LAA.getInfo(&L); }); } @@ -638,6 +660,8 @@ public: AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); + LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU); } }; @@ -653,6 +677,8 @@ INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass) INITIALIZE_PASS_END(LoopLoadElimination, LLE_OPTION, LLE_name, false, false) FunctionPass *llvm::createLoopLoadEliminationPass() { @@ -668,12 +694,18 @@ PreservedAnalyses LoopLoadEliminationPass::run(Function &F, auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &MAM = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); + auto *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + auto *BFI = (PSI && PSI->hasProfileSummary()) ? + &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; + MemorySSA *MSSA = EnableMSSALoopDependency + ? &AM.getResult<MemorySSAAnalysis>(F).getMSSA() + : nullptr; auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); bool Changed = eliminateLoadsAcrossLoops( - F, LI, DT, [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, - SE, TLI, TTI, nullptr}; + F, LI, DT, BFI, PSI, [&](Loop &L) -> const LoopAccessInfo & { + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI, MSSA}; return LAM.getResult<LoopAccessAnalysis>(L, AR); }); diff --git a/lib/Transforms/Scalar/LoopPassManager.cpp b/lib/Transforms/Scalar/LoopPassManager.cpp index 774ad7b945a0..f3bfbd3564ab 100644 --- a/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/lib/Transforms/Scalar/LoopPassManager.cpp @@ -1,9 +1,8 @@ //===- LoopPassManager.cpp - Loop pass management -------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp index 5983c804c0c1..507a1e251ca6 100644 --- a/lib/Transforms/Scalar/LoopPredication.cpp +++ b/lib/Transforms/Scalar/LoopPredication.cpp @@ -1,9 +1,8 @@ //===-- LoopPredication.cpp - Guard based loop predication pass -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -179,6 +178,7 @@ #include "llvm/Transforms/Scalar/LoopPredication.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" @@ -194,6 +194,7 @@ #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #define DEBUG_TYPE "loop-predication" @@ -222,24 +223,31 @@ static cl::opt<float> LatchExitProbabilityScale( cl::desc("scale factor for the latch probability. Value should be greater " "than 1. Lower values are ignored")); +static cl::opt<bool> PredicateWidenableBranchGuards( + "loop-predication-predicate-widenable-branches-to-deopt", cl::Hidden, + cl::desc("Whether or not we should predicate guards " + "expressed as widenable branches to deoptimize blocks"), + cl::init(true)); + namespace { -class LoopPredication { - /// Represents an induction variable check: - /// icmp Pred, <induction variable>, <loop invariant limit> - struct LoopICmp { - ICmpInst::Predicate Pred; - const SCEVAddRecExpr *IV; - const SCEV *Limit; - LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV, - const SCEV *Limit) - : Pred(Pred), IV(IV), Limit(Limit) {} - LoopICmp() {} - void dump() { - dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV - << ", Limit = " << *Limit << "\n"; - } - }; +/// Represents an induction variable check: +/// icmp Pred, <induction variable>, <loop invariant limit> +struct LoopICmp { + ICmpInst::Predicate Pred; + const SCEVAddRecExpr *IV; + const SCEV *Limit; + LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV, + const SCEV *Limit) + : Pred(Pred), IV(IV), Limit(Limit) {} + LoopICmp() {} + void dump() { + dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV + << ", Limit = " << *Limit << "\n"; + } +}; +class LoopPredication { + AliasAnalysis *AA; ScalarEvolution *SE; BranchProbabilityInfo *BPI; @@ -249,58 +257,53 @@ class LoopPredication { LoopICmp LatchCheck; bool isSupportedStep(const SCEV* Step); - Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI) { - return parseLoopICmp(ICI->getPredicate(), ICI->getOperand(0), - ICI->getOperand(1)); - } - Optional<LoopICmp> parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, - Value *RHS); - + Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); Optional<LoopICmp> parseLoopLatchICmp(); - bool CanExpand(const SCEV* S); - Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder, - ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - Instruction 
*InsertAt); + /// Return an insertion point suitable for inserting a safe to speculate + /// instruction whose only user will be 'User' which has operands 'Ops'. A + /// trivial result would be at the User itself, but we try to return a + /// loop invariant location if possible. + Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops); + /// Same as above, *except* that this uses the SCEV definition of invariant + /// which is that an expression *can be made* invariant via SCEVExpander. + /// Thus, this version is only suitable for finding an insert point to be + /// passed to SCEVExpander! + Instruction *findInsertPt(Instruction *User, ArrayRef<const SCEV*> Ops); + + /// Return true if the value is known to produce a single fixed value across + /// all iterations on which it executes. Note that this does not imply + /// speculation safety. That must be established separately. + bool isLoopInvariantValue(const SCEV* S); + + Value *expandCheck(SCEVExpander &Expander, Instruction *Guard, + ICmpInst::Predicate Pred, const SCEV *LHS, + const SCEV *RHS); Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, - IRBuilder<> &Builder); + Instruction *Guard); Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, - IRBuilder<> &Builder); + Instruction *Guard); Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, - IRBuilder<> &Builder); + Instruction *Guard); + unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition, + SCEVExpander &Expander, Instruction *Guard); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); - + bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander); // If the loop always exits through another block in the loop, we should not // predicate based on the latch check. For example, the latch check can be a // very coarse grained check and there can be more fine grained exit checks // within the loop. We identify such unprofitable loops through BPI. bool isLoopProfitableToPredicate(); - // When the IV type is wider than the range operand type, we can still do loop - // predication, by generating SCEVs for the range and latch that are of the - // same type. We achieve this by generating a SCEV truncate expression for the - // latch IV. This is done iff truncation of the IV is a safe operation, - // without loss of information. - // Another way to achieve this is by generating a wider type SCEV for the - // range check operand, however, this needs a more involved check that - // operands do not overflow. This can lead to loss of information when the - // range operand is of the form: add i32 %offset, %iv. We need to prove that - // sext(x + y) is same as sext(x) + sext(y). - // This function returns true if we can safely represent the IV type in - // the RangeCheckType without loss of information. - bool isSafeToTruncateWideIVType(Type *RangeCheckType); - // Return the loopLatchCheck corresponding to the RangeCheckType if safe to do - // so.
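// A hedged aside (not taken from the patch above): besides the
// llvm.experimental.guard intrinsic, the pass now also recognizes guards
// spelled as a conditional branch on (cond & llvm.experimental.widenable.condition()),
// whose false edge leads to a deoptimizing block. An illustrative construction
// of such a branch; block names and the builder position are invented.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
using namespace llvm;

static void emitWidenableBranchGuard(IRBuilder<> &Builder, Value *Cond,
                                     BasicBlock *GuardedBB,
                                     BasicBlock *DeoptBB) {
  Module *M = GuardedBB->getModule();
  Function *WCFn = Intrinsic::getDeclaration(
      M, Intrinsic::experimental_widenable_condition);
  Value *WC = Builder.CreateCall(WCFn);        // i1 that may later be widened
  Value *GuardCond = Builder.CreateAnd(Cond, WC);
  Builder.CreateCondBr(GuardCond, GuardedBB, DeoptBB);
}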
- Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType); - public: - LoopPredication(ScalarEvolution *SE, BranchProbabilityInfo *BPI) - : SE(SE), BPI(BPI){}; + LoopPredication(AliasAnalysis *AA, ScalarEvolution *SE, + BranchProbabilityInfo *BPI) + : AA(AA), SE(SE), BPI(BPI){}; bool runOnLoop(Loop *L); }; @@ -322,7 +325,8 @@ public: auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); - LoopPredication LP(SE, &BPI); + auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + LoopPredication LP(AA, SE, &BPI); return LP.runOnLoop(L); } }; @@ -348,16 +352,19 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); Function *F = L.getHeader()->getParent(); auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); - LoopPredication LP(&AR.SE, BPI); + LoopPredication LP(&AR.AA, &AR.SE, BPI); if (!LP.runOnLoop(&L)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); } -Optional<LoopPredication::LoopICmp> -LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, - Value *RHS) { +Optional<LoopICmp> +LoopPredication::parseLoopICmp(ICmpInst *ICI) { + auto Pred = ICI->getPredicate(); + auto *LHS = ICI->getOperand(0); + auto *RHS = ICI->getOperand(1); + const SCEV *LHSS = SE->getSCEV(LHS); if (isa<SCEVCouldNotCompute>(LHSS)) return None; @@ -380,42 +387,98 @@ LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, } Value *LoopPredication::expandCheck(SCEVExpander &Expander, - IRBuilder<> &Builder, + Instruction *Guard, ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, Instruction *InsertAt) { - // TODO: we can check isLoopEntryGuardedByCond before emitting the check - + const SCEV *RHS) { Type *Ty = LHS->getType(); assert(Ty == RHS->getType() && "expandCheck operands have different types?"); - if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS)) - return Builder.getTrue(); + if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) { + IRBuilder<> Builder(Guard); + if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS)) + return Builder.getTrue(); + if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred), + LHS, RHS)) + return Builder.getFalse(); + } - Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt); - Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt); + Value *LHSV = Expander.expandCodeFor(LHS, Ty, findInsertPt(Guard, {LHS})); + Value *RHSV = Expander.expandCodeFor(RHS, Ty, findInsertPt(Guard, {RHS})); + IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV})); return Builder.CreateICmp(Pred, LHSV, RHSV); } -Optional<LoopPredication::LoopICmp> -LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) { + +// Returns true if its safe to truncate the IV to RangeCheckType. +// When the IV type is wider than the range operand type, we can still do loop +// predication, by generating SCEVs for the range and latch that are of the +// same type. We achieve this by generating a SCEV truncate expression for the +// latch IV. This is done iff truncation of the IV is a safe operation, +// without loss of information. +// Another way to achieve this is by generating a wider type SCEV for the +// range check operand, however, this needs a more involved check that +// operands do not overflow. This can lead to loss of information when the +// range operand is of the form: add i32 %offset, %iv. 
We need to prove that +// sext(x + y) is same as sext(x) + sext(y). +// This function returns true if we can safely represent the IV type in +// the RangeCheckType without loss of information. +static bool isSafeToTruncateWideIVType(const DataLayout &DL, + ScalarEvolution &SE, + const LoopICmp LatchCheck, + Type *RangeCheckType) { + if (!EnableIVTruncation) + return false; + assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()) > + DL.getTypeSizeInBits(RangeCheckType) && + "Expected latch check IV type to be larger than range check operand " + "type!"); + // The start and end values of the IV should be known. This is to guarantee + // that truncating the wide type will not lose information. + auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit); + auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart()); + if (!Limit || !Start) + return false; + // This check makes sure that the IV does not change sign during loop + // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE, + // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the + // IV wraps around, and the truncation of the IV would lose the range of + // iterations between 2^32 and 2^64. + bool Increasing; + if (!SE.isMonotonicPredicate(LatchCheck.IV, LatchCheck.Pred, Increasing)) + return false; + // The active bits should be less than the bits in the RangeCheckType. This + // guarantees that truncating the latch check to RangeCheckType is a safe + // operation. + auto RangeCheckTypeBitSize = DL.getTypeSizeInBits(RangeCheckType); + return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize && + Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; +} + + +// Return a LoopICmp describing a latch check equivalent to LatchCheck but with +// the requested type if safe to do so. May involve the use of a new IV. +static Optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL, + ScalarEvolution &SE, + const LoopICmp LatchCheck, + Type *RangeCheckType) { auto *LatchType = LatchCheck.IV->getType(); if (RangeCheckType == LatchType) return LatchCheck; // For now, bail out if latch type is narrower than range type. - if (DL->getTypeSizeInBits(LatchType) < DL->getTypeSizeInBits(RangeCheckType)) + if (DL.getTypeSizeInBits(LatchType) < DL.getTypeSizeInBits(RangeCheckType)) return None; - if (!isSafeToTruncateWideIVType(RangeCheckType)) + if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType)) return None; // We can now safely identify the truncated version of the IV and limit for // RangeCheckType.
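// A hedged aside (not taken from the patch above): a concrete instance of the
// active-bits test used above, with invented numbers. A 64-bit latch IV running
// from 0 to a constant limit of 512 needs at most 10 active bits, so restating
// the latch check in a 32-bit range-check type cannot drop any iterations.
#include "llvm/ADT/APInt.h"
using namespace llvm;

static bool truncationExampleIsSafe() {
  APInt Start(64, 0), Limit(64, 512);
  unsigned RangeCheckBits = 32;
  return Start.getActiveBits() < RangeCheckBits &&  // 0 bits  < 32
         Limit.getActiveBits() < RangeCheckBits;    // 10 bits < 32
}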
LoopICmp NewLatchCheck; NewLatchCheck.Pred = LatchCheck.Pred; NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>( - SE->getTruncateExpr(LatchCheck.IV, RangeCheckType)); + SE.getTruncateExpr(LatchCheck.IV, RangeCheckType)); if (!NewLatchCheck.IV) return None; - NewLatchCheck.Limit = SE->getTruncateExpr(LatchCheck.Limit, RangeCheckType); + NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType); LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType << "can be represented as range check type:" << *RangeCheckType << "\n"); @@ -428,13 +491,66 @@ bool LoopPredication::isSupportedStep(const SCEV* Step) { return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop); } -bool LoopPredication::CanExpand(const SCEV* S) { - return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE); +Instruction *LoopPredication::findInsertPt(Instruction *Use, + ArrayRef<Value*> Ops) { + for (Value *Op : Ops) + if (!L->isLoopInvariant(Op)) + return Use; + return Preheader->getTerminator(); +} + +Instruction *LoopPredication::findInsertPt(Instruction *Use, + ArrayRef<const SCEV*> Ops) { + // Subtlety: SCEV considers things to be invariant if the value produced is + // the same across iterations. This is not the same as being able to + // evaluate outside the loop, which is what we actually need here. + for (const SCEV *Op : Ops) + if (!SE->isLoopInvariant(Op, L) || + !isSafeToExpandAt(Op, Preheader->getTerminator(), *SE)) + return Use; + return Preheader->getTerminator(); +} + +bool LoopPredication::isLoopInvariantValue(const SCEV* S) { + // Handling expressions which produce invariant results, but *haven't* yet + // been removed from the loop serves two important purposes. + // 1) Most importantly, it resolves a pass ordering cycle which would + // otherwise need us to iterate licm, loop-predication, and either + // loop-unswitch or loop-peeling to make progress on examples with lots of + // predicable range checks in a row. (Since, in the general case, we can't + // hoist the length checks until the dominating checks have been discharged + // as we can't prove doing so is safe.) + // 2) As a nice side effect, this exposes the value of peeling or unswitching + // much more obviously in the IR. Otherwise, the cost modeling for other + // transforms would end up needing to duplicate all of this logic to model a + // check which becomes predictable based on a modeled peel or unswitch. + // + // The cost of doing so in the worst case is an extra fill from the stack in + // the loop to materialize the loop invariant test value instead of checking + // against the original IV which is presumably in a register inside the loop. + // Such cases are presumably rare, and hint at missing opportunities for + // other passes. + + if (SE->isLoopInvariant(S, L)) + // Note: This is the SCEV variant, so the original Value* may be within the + // loop even though SCEV has proven it is loop invariant. + return true; + + // Handle a particularly important case which SCEV doesn't yet know about + // which shows up in range checks on arrays with immutable lengths. + // TODO: This should be sunk inside SCEV.
+ if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) + if (const auto *LI = dyn_cast<LoadInst>(U->getValue())) + if (LI->isUnordered() && L->hasLoopInvariantOperands(LI)) + if (AA->pointsToConstantMemory(LI->getOperand(0)) || + LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr) + return true; + return false; } Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( - LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck, - SCEVExpander &Expander, IRBuilder<> &Builder) { + LoopICmp LatchCheck, LoopICmp RangeCheck, + SCEVExpander &Expander, Instruction *Guard) { auto *Ty = RangeCheck.IV->getType(); // Generate the widened condition for the forward loop: // guardStart u< guardLimit && @@ -446,40 +562,61 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( const SCEV *GuardLimit = RangeCheck.Limit; const SCEV *LatchStart = LatchCheck.IV->getStart(); const SCEV *LatchLimit = LatchCheck.Limit; + // Subtlety: We need all the values to be *invariant* across all iterations, + // but we only need to check expansion safety for those which *aren't* + // already guaranteed to dominate the guard. + if (!isLoopInvariantValue(GuardStart) || + !isLoopInvariantValue(GuardLimit) || + !isLoopInvariantValue(LatchStart) || + !isLoopInvariantValue(LatchLimit)) { + LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); + return None; + } + if (!isSafeToExpandAt(LatchStart, Guard, *SE) || + !isSafeToExpandAt(LatchLimit, Guard, *SE)) { + LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); + return None; + } // guardLimit - guardStart + latchStart - 1 const SCEV *RHS = SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart), SE->getMinusSCEV(LatchStart, SE->getOne(Ty))); - if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || - !CanExpand(LatchLimit) || !CanExpand(RHS)) { - LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); - return None; - } auto LimitCheckPred = ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred); LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n"); LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n"); LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); - - Instruction *InsertAt = Preheader->getTerminator(); + auto *LimitCheck = - expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS, InsertAt); - auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck.Pred, - GuardStart, GuardLimit, InsertAt); + expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS); + auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred, + GuardStart, GuardLimit); + IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck})); return Builder.CreateAnd(FirstIterationCheck, LimitCheck); } Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( - LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck, - SCEVExpander &Expander, IRBuilder<> &Builder) { + LoopICmp LatchCheck, LoopICmp RangeCheck, + SCEVExpander &Expander, Instruction *Guard) { auto *Ty = RangeCheck.IV->getType(); const SCEV *GuardStart = RangeCheck.IV->getStart(); const SCEV *GuardLimit = RangeCheck.Limit; + const SCEV *LatchStart = LatchCheck.IV->getStart(); const SCEV *LatchLimit = LatchCheck.Limit; - if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || - !CanExpand(LatchLimit)) { + // Subtlety: We need all the values to be *invariant* across all iterations, + // but we only need to check expansion safety for those which *aren't* + // already guaranteed to dominate the guard. 
+ if (!isLoopInvariantValue(GuardStart) || + !isLoopInvariantValue(GuardLimit) || + !isLoopInvariantValue(LatchStart) || + !isLoopInvariantValue(LatchLimit)) { + LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); + return None; + } + if (!isSafeToExpandAt(LatchStart, Guard, *SE) || + !isSafeToExpandAt(LatchLimit, Guard, *SE)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); return None; } @@ -497,22 +634,35 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( // guardStart u< guardLimit && // latchLimit <pred> 1. // See the header comment for reasoning of the checks. - Instruction *InsertAt = Preheader->getTerminator(); auto LimitCheckPred = ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred); - auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT, - GuardStart, GuardLimit, InsertAt); - auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, - SE->getOne(Ty), InsertAt); + auto *FirstIterationCheck = expandCheck(Expander, Guard, + ICmpInst::ICMP_ULT, + GuardStart, GuardLimit); + auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, + SE->getOne(Ty)); + IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck})); return Builder.CreateAnd(FirstIterationCheck, LimitCheck); } +static void normalizePredicate(ScalarEvolution *SE, Loop *L, + LoopICmp& RC) { + // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the + // ULT/UGE form for ease of handling by our caller. + if (ICmpInst::isEquality(RC.Pred) && + RC.IV->getStepRecurrence(*SE)->isOne() && + SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit)) + RC.Pred = RC.Pred == ICmpInst::ICMP_NE ? + ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; +} + + /// If ICI can be widened to a loop invariant condition emits the loop /// invariant condition in the loop preheader and return it, otherwise /// returns None. Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, - IRBuilder<> &Builder) { + Instruction *Guard) { LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); LLVM_DEBUG(ICI->dump()); @@ -545,7 +695,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, return None; } auto *Ty = RangeCheckIV->getType(); - auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty); + auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty); if (!CurrLatchCheckOpt) { LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check " "corresponding to range type: " @@ -566,34 +716,27 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, if (Step->isOne()) return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck, - Expander, Builder); + Expander, Guard); else { assert(Step->isAllOnesValue() && "Step should be -1!"); return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck, - Expander, Builder); + Expander, Guard); } } -bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, - SCEVExpander &Expander) { - LLVM_DEBUG(dbgs() << "Processing guard:\n"); - LLVM_DEBUG(Guard->dump()); - - TotalConsidered++; - - IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator())); - +unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks, + Value *Condition, + SCEVExpander &Expander, + Instruction *Guard) { + unsigned NumWidened = 0; // The guard condition is expected to be in form of: // cond1 && cond2 && cond3 ... 
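// A hedged aside (not taken from the patch above): the walk below flattens a
// chain of 'and's into its leaf conditions before deciding which leaves can be
// widened. A self-contained sketch of that flattening step; names are invented.
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"
using namespace llvm;
using namespace llvm::PatternMatch;

static void collectAndedLeaves(Value *Root, SmallVectorImpl<Value *> &Leaves) {
  SmallVector<Value *, 4> Worklist(1, Root);
  SmallPtrSet<Value *, 4> Visited;
  do {
    Value *V = Worklist.pop_back_val();
    if (!Visited.insert(V).second)
      continue; // already reached through another 'and'
    Value *A, *B;
    if (match(V, m_And(m_Value(A), m_Value(B)))) {
      Worklist.push_back(A);
      Worklist.push_back(B);
      continue;
    }
    Leaves.push_back(V); // leaf: icmp, widenable condition, or anything else
  } while (!Worklist.empty());
}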
// Iterate over subconditions looking for icmp conditions which can be // widened across loop iterations. Widening these conditions remember the // resulting list of subconditions in Checks vector. - SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0)); + SmallVector<Value *, 4> Worklist(1, Condition); SmallPtrSet<Value *, 4> Visited; - - SmallVector<Value *, 4> Checks; - - unsigned NumWidened = 0; + Value *WideableCond = nullptr; do { Value *Condition = Worklist.pop_back_val(); if (!Visited.insert(Condition).second) @@ -607,8 +750,16 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, continue; } + if (match(Condition, + m_Intrinsic<Intrinsic::experimental_widenable_condition>())) { + // Pick any, we don't care which + WideableCond = Condition; + continue; + } + if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) { - if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) { + if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, + Guard)) { Checks.push_back(NewRangeCheck.getValue()); NumWidened++; continue; @@ -617,28 +768,70 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, // Save the condition as is if we can't widen it Checks.push_back(Condition); - } while (Worklist.size() != 0); + } while (!Worklist.empty()); + // At the moment, our matching logic for wideable conditions implicitly + // assumes we preserve the form: (br (and Cond, WC())). FIXME + // Note that if there were multiple calls to wideable condition in the + // traversal, we only need to keep one, and which one is arbitrary. + if (WideableCond) + Checks.push_back(WideableCond); + return NumWidened; +} + +bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, + SCEVExpander &Expander) { + LLVM_DEBUG(dbgs() << "Processing guard:\n"); + LLVM_DEBUG(Guard->dump()); + TotalConsidered++; + SmallVector<Value *, 4> Checks; + unsigned NumWidened = collectChecks(Checks, Guard->getOperand(0), Expander, + Guard); + if (NumWidened == 0) + return false; + + TotalWidened += NumWidened; + + // Emit the new guard condition + IRBuilder<> Builder(findInsertPt(Guard, Checks)); + Value *AllChecks = Builder.CreateAnd(Checks); + auto *OldCond = Guard->getOperand(0); + Guard->setOperand(0, AllChecks); + RecursivelyDeleteTriviallyDeadInstructions(OldCond); + + LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); + return true; +} + +bool LoopPredication::widenWidenableBranchGuardConditions( + BranchInst *BI, SCEVExpander &Expander) { + assert(isGuardAsWidenableBranch(BI) && "Must be!"); + LLVM_DEBUG(dbgs() << "Processing guard:\n"); + LLVM_DEBUG(BI->dump()); + + TotalConsidered++; + SmallVector<Value *, 4> Checks; + unsigned NumWidened = collectChecks(Checks, BI->getCondition(), + Expander, BI); if (NumWidened == 0) return false; TotalWidened += NumWidened; // Emit the new guard condition - Builder.SetInsertPoint(Guard); - Value *LastCheck = nullptr; - for (auto *Check : Checks) - if (!LastCheck) - LastCheck = Check; - else - LastCheck = Builder.CreateAnd(LastCheck, Check); - Guard->setOperand(0, LastCheck); + IRBuilder<> Builder(findInsertPt(BI, Checks)); + Value *AllChecks = Builder.CreateAnd(Checks); + auto *OldCond = BI->getCondition(); + BI->setCondition(AllChecks); + assert(isGuardAsWidenableBranch(BI) && + "Stopped being a guard after transform?"); + RecursivelyDeleteTriviallyDeadInstructions(OldCond); LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); return true; } -Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { 
+Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { using namespace PatternMatch; BasicBlock *LoopLatch = L->getLoopLatch(); @@ -647,27 +840,30 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { return None; } - ICmpInst::Predicate Pred; - Value *LHS, *RHS; - BasicBlock *TrueDest, *FalseDest; - - if (!match(LoopLatch->getTerminator(), - m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest, - FalseDest))) { + auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator()); + if (!BI || !BI->isConditional()) { LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n"); return None; } - assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) && - "One of the latch's destinations must be the header"); - if (TrueDest != L->getHeader()) - Pred = ICmpInst::getInversePredicate(Pred); - - auto Result = parseLoopICmp(Pred, LHS, RHS); + BasicBlock *TrueDest = BI->getSuccessor(0); + assert( + (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) && + "One of the latch's destinations must be the header"); + + auto *ICI = dyn_cast<ICmpInst>(BI->getCondition()); + if (!ICI) { + LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n"); + return None; + } + auto Result = parseLoopICmp(ICI); if (!Result) { LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); return None; } + if (TrueDest != L->getHeader()) + Result->Pred = ICmpInst::getInversePredicate(Result->Pred); + // Check affine first, so if it's not we don't try to compute the step // recurrence. if (!Result->IV->isAffine()) { @@ -692,49 +888,22 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { } }; + normalizePredicate(SE, L, *Result); if (IsUnsupportedPredicate(Step, Result->Pred)) { LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred << ")!\n"); return None; } + return Result; } -// Returns true if its safe to truncate the IV to RangeCheckType. -bool LoopPredication::isSafeToTruncateWideIVType(Type *RangeCheckType) { - if (!EnableIVTruncation) - return false; - assert(DL->getTypeSizeInBits(LatchCheck.IV->getType()) > - DL->getTypeSizeInBits(RangeCheckType) && - "Expected latch check IV type to be larger than range check operand " - "type!"); - // The start and end values of the IV should be known. This is to guarantee - // that truncating the wide type will not lose information. - auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit); - auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart()); - if (!Limit || !Start) - return false; - // This check makes sure that the IV does not change sign during loop - // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE, - // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the - // IV wraps around, and the truncation of the IV would lose the range of - // iterations between 2^32 and 2^64. - bool Increasing; - if (!SE->isMonotonicPredicate(LatchCheck.IV, LatchCheck.Pred, Increasing)) - return false; - // The active bits should be less than the bits in the RangeCheckType. This - // guarantees that truncating the latch check to RangeCheckType is a safe - // operation. 
- auto RangeCheckTypeBitSize = DL->getTypeSizeInBits(RangeCheckType); - return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize && - Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; -} bool LoopPredication::isLoopProfitableToPredicate() { if (SkipProfitabilityChecks || !BPI) return true; - SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 8> ExitEdges; + SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges; L->getExitEdges(ExitEdges); // If there is only one exiting edge in the loop, it is always profitable to // predicate the loop. @@ -795,7 +964,12 @@ bool LoopPredication::runOnLoop(Loop *Loop) { // There is nothing to do if the module doesn't use guards auto *GuardDecl = M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard)); - if (!GuardDecl || GuardDecl->use_empty()) + bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty(); + auto *WCDecl = M->getFunction( + Intrinsic::getName(Intrinsic::experimental_widenable_condition)); + bool HasWidenableConditions = + PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty(); + if (!HasIntrinsicGuards && !HasWidenableConditions) return false; DL = &M->getDataLayout(); @@ -819,12 +993,18 @@ bool LoopPredication::runOnLoop(Loop *Loop) { // Collect all the guards into a vector and process later, so as not // to invalidate the instruction iterator. SmallVector<IntrinsicInst *, 4> Guards; - for (const auto BB : L->blocks()) + SmallVector<BranchInst *, 4> GuardsAsWidenableBranches; + for (const auto BB : L->blocks()) { for (auto &I : *BB) if (isGuard(&I)) Guards.push_back(cast<IntrinsicInst>(&I)); + if (PredicateWidenableBranchGuards && + isGuardAsWidenableBranch(BB->getTerminator())) + GuardsAsWidenableBranches.push_back( + cast<BranchInst>(BB->getTerminator())); + } - if (Guards.empty()) + if (Guards.empty() && GuardsAsWidenableBranches.empty()) return false; SCEVExpander Expander(*SE, *DL, "loop-predication"); @@ -832,6 +1012,8 @@ bool LoopPredication::runOnLoop(Loop *Loop) { bool Changed = false; for (auto *Guard : Guards) Changed |= widenGuardConditions(Guard, Expander); + for (auto *Guard : GuardsAsWidenableBranches) + Changed |= widenWidenableBranchGuardConditions(Guard, Expander); return Changed; } diff --git a/lib/Transforms/Scalar/LoopRerollPass.cpp b/lib/Transforms/Scalar/LoopRerollPass.cpp index 9a99e5925572..166b57f20b43 100644 --- a/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -1,9 +1,8 @@ //===- LoopReroll.cpp - Loop rerolling pass -------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -892,12 +891,22 @@ bool LoopReroll::DAGRootTracker::validateRootSet(DAGRootSet &DRS) { const auto *ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst)); if (!ADR) return false; + + // Check that the first root is evenly spaced. 
unsigned N = DRS.Roots.size() + 1; const SCEV *StepSCEV = SE->getMinusSCEV(SE->getSCEV(DRS.Roots[0]), ADR); const SCEV *ScaleSCEV = SE->getConstant(StepSCEV->getType(), N); if (ADR->getStepRecurrence(*SE) != SE->getMulExpr(StepSCEV, ScaleSCEV)) return false; + // Check that the remainling roots are evenly spaced. + for (unsigned i = 1; i < N - 1; ++i) { + const SCEV *NewStepSCEV = SE->getMinusSCEV(SE->getSCEV(DRS.Roots[i]), + SE->getSCEV(DRS.Roots[i-1])); + if (NewStepSCEV != StepSCEV) + return false; + } + return true; } diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp index fd22128f7fe6..e009947690af 100644 --- a/lib/Transforms/Scalar/LoopRotation.cpp +++ b/lib/Transforms/Scalar/LoopRotation.cpp @@ -1,9 +1,8 @@ //===- LoopRotation.cpp - Loop Rotation Pass ------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -55,7 +54,10 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, if (AR.MSSA && VerifyMemorySSA) AR.MSSA->verifyMemorySSA(); - return getLoopPassPreservedAnalyses(); + auto PA = getLoopPassPreservedAnalyses(); + if (EnableMSSALoopDependency) + PA.preserve<MemorySSAAnalysis>(); + return PA; } namespace { diff --git a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index 2e5927f9a068..046f4c8af492 100644 --- a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -1,9 +1,8 @@ //===--------- LoopSimplifyCFG.cpp - Loop CFG Simplification Pass ---------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -21,6 +20,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/DependenceAnalysis.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -29,7 +29,6 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" @@ -42,7 +41,7 @@ using namespace llvm; #define DEBUG_TYPE "loop-simplifycfg" static cl::opt<bool> EnableTermFolding("enable-loop-simplifycfg-term-folding", - cl::init(false)); + cl::init(true)); STATISTIC(NumTerminatorsFolded, "Number of terminators folded to unconditional branches"); @@ -80,6 +79,36 @@ static BasicBlock *getOnlyLiveSuccessor(BasicBlock *BB) { return nullptr; } +/// Removes \p BB from all loops from [FirstLoop, LastLoop) in parent chain. 
+static void removeBlockFromLoops(BasicBlock *BB, Loop *FirstLoop, + Loop *LastLoop = nullptr) { + assert((!LastLoop || LastLoop->contains(FirstLoop->getHeader())) && + "First loop is supposed to be inside of last loop!"); + assert(FirstLoop->contains(BB) && "Must be a loop block!"); + for (Loop *Current = FirstLoop; Current != LastLoop; + Current = Current->getParentLoop()) + Current->removeBlockFromLoop(BB); +} + +/// Find innermost loop that contains at least one block from \p BBs and +/// contains the header of loop \p L. +static Loop *getInnermostLoopFor(SmallPtrSetImpl<BasicBlock *> &BBs, + Loop &L, LoopInfo &LI) { + Loop *Innermost = nullptr; + for (BasicBlock *BB : BBs) { + Loop *BBL = LI.getLoopFor(BB); + while (BBL && !BBL->contains(L.getHeader())) + BBL = BBL->getParentLoop(); + if (BBL == &L) + BBL = BBL->getParentLoop(); + if (!BBL) + continue; + if (!Innermost || BBL->getLoopDepth() > Innermost->getLoopDepth()) + Innermost = BBL; + } + return Innermost; +} + namespace { /// Helper class that can turn branches and switches with constant conditions /// into unconditional branches. @@ -90,6 +119,9 @@ private: DominatorTree &DT; ScalarEvolution &SE; MemorySSAUpdater *MSSAU; + LoopBlocksDFS DFS; + DomTreeUpdater DTU; + SmallVector<DominatorTree::UpdateType, 16> DTUpdates; // Whether or not the current loop has irreducible CFG. bool HasIrreducibleCFG = false; @@ -175,7 +207,6 @@ private: /// Fill all information about status of blocks and exits of the current loop /// if constant folding of all branches will be done. void analyze() { - LoopBlocksDFS DFS(&L); DFS.perform(&LI); assert(DFS.isComplete() && "DFS is expected to be finished"); @@ -208,12 +239,13 @@ private: // folding. Only handle blocks from current loop: branches in child loops // are skipped because if they can be folded, they should be folded during // the processing of child loops. - if (TheOnlySucc && LI.getLoopFor(BB) == &L) + bool TakeFoldCandidate = TheOnlySucc && LI.getLoopFor(BB) == &L; + if (TakeFoldCandidate) FoldCandidates.push_back(BB); // Handle successors. for (BasicBlock *Succ : successors(BB)) - if (!TheOnlySucc || TheOnlySucc == Succ) { + if (!TakeFoldCandidate || TheOnlySucc == Succ) { if (L.contains(Succ)) LiveLoopBlocks.insert(Succ); else @@ -229,8 +261,10 @@ private: // Now, all exit blocks that are not marked as live are dead. SmallVector<BasicBlock *, 8> ExitBlocks; L.getExitBlocks(ExitBlocks); + SmallPtrSet<BasicBlock *, 8> UniqueDeadExits; for (auto *ExitBlock : ExitBlocks) - if (!LiveExitBlocks.count(ExitBlock)) + if (!LiveExitBlocks.count(ExitBlock) && + UniqueDeadExits.insert(ExitBlock).second) DeadExitBlocks.push_back(ExitBlock); // Whether or not the edge From->To will still be present in graph after the @@ -239,7 +273,7 @@ private: if (!LiveLoopBlocks.count(From)) return false; BasicBlock *TheOnlySucc = getOnlyLiveSuccessor(From); - return !TheOnlySucc || TheOnlySucc == To; + return !TheOnlySucc || TheOnlySucc == To || LI.getLoopFor(From) != &L; }; // The loop will not be destroyed if its latch is live. @@ -317,14 +351,10 @@ private: // Construct split preheader and the dummy switch to thread edges from it to // dead exits. 
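// Illustrative aside (hypothetical toy CFG, not part of the patch): the
// analyze() step above walks the loop and, for a block whose terminator folds
// to a constant, only follows its single live successor. A plain C++ sketch
// of that liveness propagation:
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

int main() {
  // Toy loop CFG: header -> {body, exit}; body's branch condition is known to
  // be constant, so only "latch" is live out of it; latch -> header.
  std::map<std::string, std::vector<std::string>> Succs = {
      {"header", {"body", "exit"}},
      {"body", {"latch", "deadexit"}},
      {"latch", {"header"}}};
  // Blocks whose terminator folds, mapped to their only live successor.
  std::map<std::string, std::string> OnlyLiveSucc = {{"body", "latch"}};

  std::set<std::string> Live = {"header"};
  std::vector<std::string> Worklist = {"header"};
  while (!Worklist.empty()) {
    std::string BB = Worklist.back();
    Worklist.pop_back();
    auto Folded = OnlyLiveSucc.find(BB);
    for (const std::string &S : Succs[BB]) {
      if (Folded != OnlyLiveSucc.end() && S != Folded->second)
        continue; // dead edge after constant folding
      if (Live.insert(S).second)
        Worklist.push_back(S);
    }
  }
  for (const std::string &BB : Live)
    std::printf("live: %s\n", BB.c_str());
  // "deadexit" never becomes live; it is the kind of exit that gets threaded
  // to the dummy switch built below.
  return 0;
}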
- DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); BasicBlock *Preheader = L.getLoopPreheader(); - BasicBlock *NewPreheader = Preheader->splitBasicBlock( - Preheader->getTerminator(), - Twine(Preheader->getName()).concat("-split")); - DTU.deleteEdge(Preheader, L.getHeader()); - DTU.insertEdge(NewPreheader, L.getHeader()); - DTU.insertEdge(Preheader, NewPreheader); + BasicBlock *NewPreheader = llvm::SplitBlock( + Preheader, Preheader->getTerminator(), &DT, &LI, MSSAU); + IRBuilder<> Builder(Preheader->getTerminator()); SwitchInst *DummySwitch = Builder.CreateSwitch(Builder.getInt32(0), NewPreheader); @@ -343,75 +373,106 @@ private: } assert(DummyIdx != 0 && "Too many dead exits!"); DummySwitch->addCase(Builder.getInt32(DummyIdx++), BB); - DTU.insertEdge(Preheader, BB); + DTUpdates.push_back({DominatorTree::Insert, Preheader, BB}); ++NumLoopExitsDeleted; } assert(L.getLoopPreheader() == NewPreheader && "Malformed CFG?"); if (Loop *OuterLoop = LI.getLoopFor(Preheader)) { - OuterLoop->addBasicBlockToLoop(NewPreheader, LI); - // When we break dead edges, the outer loop may become unreachable from // the current loop. We need to fix loop info accordingly. For this, we // find the most nested loop that still contains L and remove L from all // loops that are inside of it. - Loop *StillReachable = nullptr; - for (BasicBlock *BB : LiveExitBlocks) { - Loop *BBL = LI.getLoopFor(BB); - if (BBL && BBL->contains(L.getHeader())) - if (!StillReachable || - BBL->getLoopDepth() > StillReachable->getLoopDepth()) - StillReachable = BBL; - } + Loop *StillReachable = getInnermostLoopFor(LiveExitBlocks, L, LI); // Okay, our loop is no longer in the outer loop (and maybe not in some of // its parents as well). Make the fixup. if (StillReachable != OuterLoop) { LI.changeLoopFor(NewPreheader, StillReachable); - for (Loop *NotContaining = OuterLoop; NotContaining != StillReachable; - NotContaining = NotContaining->getParentLoop()) { - NotContaining->removeBlockFromLoop(NewPreheader); - for (auto *BB : L.blocks()) - NotContaining->removeBlockFromLoop(BB); - } + removeBlockFromLoops(NewPreheader, OuterLoop, StillReachable); + for (auto *BB : L.blocks()) + removeBlockFromLoops(BB, OuterLoop, StillReachable); OuterLoop->removeChildLoop(&L); if (StillReachable) StillReachable->addChildLoop(&L); else LI.addTopLevelLoop(&L); + + // Some values from loops in [OuterLoop, StillReachable) could be used + // in the current loop. Now it is not their child anymore, so such uses + // require LCSSA Phis. + Loop *FixLCSSALoop = OuterLoop; + while (FixLCSSALoop->getParentLoop() != StillReachable) + FixLCSSALoop = FixLCSSALoop->getParentLoop(); + assert(FixLCSSALoop && "Should be a loop!"); + // We need all DT updates to be done before forming LCSSA. + DTU.applyUpdates(DTUpdates); + if (MSSAU) + MSSAU->applyUpdates(DTUpdates, DT); + DTUpdates.clear(); + formLCSSARecursively(*FixLCSSALoop, DT, &LI, &SE); } } + + if (MSSAU) { + // Clear all updates now. Facilitates deletes that follow. + DTU.applyUpdates(DTUpdates); + MSSAU->applyUpdates(DTUpdates, DT); + DTUpdates.clear(); + if (VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + } } /// Delete loop blocks that have become unreachable after folding. Make all /// relevant updates to DT and LI. 
void deleteDeadLoopBlocks() { - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); if (MSSAU) { - SmallPtrSet<BasicBlock *, 8> DeadLoopBlocksSet(DeadLoopBlocks.begin(), - DeadLoopBlocks.end()); + SmallSetVector<BasicBlock *, 8> DeadLoopBlocksSet(DeadLoopBlocks.begin(), + DeadLoopBlocks.end()); MSSAU->removeBlocks(DeadLoopBlocksSet); } + + // The function LI.erase has some invariants that need to be preserved when + // it tries to remove a loop which is not the top-level loop. In particular, + // it requires loop's preheader to be strictly in loop's parent. We cannot + // just remove blocks one by one, because after removal of preheader we may + // break this invariant for the dead loop. So we detatch and erase all dead + // loops beforehand. + for (auto *BB : DeadLoopBlocks) + if (LI.isLoopHeader(BB)) { + assert(LI.getLoopFor(BB) != &L && "Attempt to remove current loop!"); + Loop *DL = LI.getLoopFor(BB); + if (DL->getParentLoop()) { + for (auto *PL = DL->getParentLoop(); PL; PL = PL->getParentLoop()) + for (auto *BB : DL->getBlocks()) + PL->removeBlockFromLoop(BB); + DL->getParentLoop()->removeChildLoop(DL); + LI.addTopLevelLoop(DL); + } + LI.erase(DL); + } + for (auto *BB : DeadLoopBlocks) { assert(BB != L.getHeader() && "Header of the current loop cannot be dead!"); LLVM_DEBUG(dbgs() << "Deleting dead loop block " << BB->getName() << "\n"); - if (LI.isLoopHeader(BB)) { - assert(LI.getLoopFor(BB) != &L && "Attempt to remove current loop!"); - LI.erase(LI.getLoopFor(BB)); - } LI.removeBlock(BB); - DeleteDeadBlock(BB, &DTU); - ++NumLoopBlocksDeleted; } + + DetatchDeadBlocks(DeadLoopBlocks, &DTUpdates, /*KeepOneInputPHIs*/true); + DTU.applyUpdates(DTUpdates); + DTUpdates.clear(); + for (auto *BB : DeadLoopBlocks) + DTU.deleteBB(BB); + + NumLoopBlocksDeleted += DeadLoopBlocks.size(); } /// Constant-fold terminators of blocks acculumated in FoldCandidates into the /// unconditional branches. void foldTerminators() { - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); - for (BasicBlock *BB : FoldCandidates) { assert(LI.getLoopFor(BB) == &L && "Should be a loop block!"); BasicBlock *TheOnlySucc = getOnlyLiveSuccessor(BB); @@ -453,7 +514,7 @@ private: Term->eraseFromParent(); for (auto *DeadSucc : DeadSuccessors) - DTU.deleteEdge(BB, DeadSucc); + DTUpdates.push_back({DominatorTree::Delete, BB, DeadSucc}); ++NumTerminatorsFolded; } @@ -463,15 +524,18 @@ public: ConstantTerminatorFoldingImpl(Loop &L, LoopInfo &LI, DominatorTree &DT, ScalarEvolution &SE, MemorySSAUpdater *MSSAU) - : L(L), LI(LI), DT(DT), SE(SE), MSSAU(MSSAU) {} + : L(L), LI(LI), DT(DT), SE(SE), MSSAU(MSSAU), DFS(&L), + DTU(DT, DomTreeUpdater::UpdateStrategy::Eager) {} bool run() { assert(L.getLoopLatch() && "Should be single latch!"); // Collect all available information about status of blocks after constant // folding. 
analyze(); + BasicBlock *Header = L.getHeader(); + (void)Header; - LLVM_DEBUG(dbgs() << "In function " << L.getHeader()->getParent()->getName() + LLVM_DEBUG(dbgs() << "In function " << Header->getParent()->getName() << ": "); if (HasIrreducibleCFG) { @@ -483,7 +547,7 @@ public: if (FoldCandidates.empty()) { LLVM_DEBUG( dbgs() << "No constant terminator folding candidates found in loop " - << L.getHeader()->getName() << "\n"); + << Header->getName() << "\n"); return false; } @@ -491,8 +555,7 @@ public: if (DeleteCurrentLoop) { LLVM_DEBUG( dbgs() - << "Give up constant terminator folding in loop " - << L.getHeader()->getName() + << "Give up constant terminator folding in loop " << Header->getName() << ": we don't currently support deletion of the current loop.\n"); return false; } @@ -503,8 +566,7 @@ public: L.getNumBlocks()) { LLVM_DEBUG( dbgs() << "Give up constant terminator folding in loop " - << L.getHeader()->getName() - << ": we don't currently" + << Header->getName() << ": we don't currently" " support blocks that are not dead, but will stop " "being a part of the loop after constant-folding.\n"); return false; @@ -515,8 +577,7 @@ public: LLVM_DEBUG(dump()); LLVM_DEBUG(dbgs() << "Constant-folding " << FoldCandidates.size() - << " terminators in loop " << L.getHeader()->getName() - << "\n"); + << " terminators in loop " << Header->getName() << "\n"); // Make the actual transforms. handleDeadExits(); @@ -524,20 +585,36 @@ public: if (!DeadLoopBlocks.empty()) { LLVM_DEBUG(dbgs() << "Deleting " << DeadLoopBlocks.size() - << " dead blocks in loop " << L.getHeader()->getName() - << "\n"); + << " dead blocks in loop " << Header->getName() << "\n"); deleteDeadLoopBlocks(); + } else { + // If we didn't do updates inside deleteDeadLoopBlocks, do them here. + DTU.applyUpdates(DTUpdates); + DTUpdates.clear(); } + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + #ifndef NDEBUG // Make sure that we have preserved all data structures after the transform. - DT.verify(); - assert(DT.isReachableFromEntry(L.getHeader())); +#if defined(EXPENSIVE_CHECKS) + assert(DT.verify(DominatorTree::VerificationLevel::Full) && + "DT broken after transform!"); +#else + assert(DT.verify(DominatorTree::VerificationLevel::Fast) && + "DT broken after transform!"); +#endif + assert(DT.isReachableFromEntry(Header)); LI.verify(DT); #endif return true; } + + bool foldingBreaksCurrentLoop() const { + return DeleteCurrentLoop; + } }; } // namespace @@ -545,7 +622,8 @@ public: /// branches. static bool constantFoldTerminators(Loop &L, DominatorTree &DT, LoopInfo &LI, ScalarEvolution &SE, - MemorySSAUpdater *MSSAU) { + MemorySSAUpdater *MSSAU, + bool &IsLoopDeleted) { if (!EnableTermFolding) return false; @@ -555,7 +633,9 @@ static bool constantFoldTerminators(Loop &L, DominatorTree &DT, LoopInfo &LI, return false; ConstantTerminatorFoldingImpl BranchFolder(L, LI, DT, SE, MSSAU); - return BranchFolder.run(); + bool Changed = BranchFolder.run(); + IsLoopDeleted = Changed && BranchFolder.foldingBreaksCurrentLoop(); + return Changed; } static bool mergeBlocksIntoPredecessors(Loop &L, DominatorTree &DT, @@ -587,11 +667,15 @@ static bool mergeBlocksIntoPredecessors(Loop &L, DominatorTree &DT, } static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI, - ScalarEvolution &SE, MemorySSAUpdater *MSSAU) { + ScalarEvolution &SE, MemorySSAUpdater *MSSAU, + bool &isLoopDeleted) { bool Changed = false; // Constant-fold terminators with known constant conditions. 
- Changed |= constantFoldTerminators(L, DT, LI, SE, MSSAU); + Changed |= constantFoldTerminators(L, DT, LI, SE, MSSAU, isLoopDeleted); + + if (isLoopDeleted) + return true; // Eliminate unconditional branches by merging blocks into their predecessors. Changed |= mergeBlocksIntoPredecessors(L, DT, LI, MSSAU); @@ -604,15 +688,23 @@ static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI, PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, - LPMUpdater &) { + LPMUpdater &LPMU) { Optional<MemorySSAUpdater> MSSAU; if (EnableMSSALoopDependency && AR.MSSA) MSSAU = MemorySSAUpdater(AR.MSSA); + bool DeleteCurrentLoop = false; if (!simplifyLoopCFG(L, AR.DT, AR.LI, AR.SE, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr)) + MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, + DeleteCurrentLoop)) return PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); + if (DeleteCurrentLoop) + LPMU.markLoopAsDeleted(L, "loop-simplifycfg"); + + auto PA = getLoopPassPreservedAnalyses(); + if (EnableMSSALoopDependency) + PA.preserve<MemorySSAAnalysis>(); + return PA; } namespace { @@ -623,7 +715,7 @@ public: initializeLoopSimplifyCFGLegacyPassPass(*PassRegistry::getPassRegistry()); } - bool runOnLoop(Loop *L, LPPassManager &) override { + bool runOnLoop(Loop *L, LPPassManager &LPM) override { if (skipLoop(L)) return false; @@ -637,8 +729,13 @@ public: if (VerifyMemorySSA) MSSA->verifyMemorySSA(); } - return simplifyLoopCFG(*L, DT, LI, SE, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr); + bool DeleteCurrentLoop = false; + bool Changed = simplifyLoopCFG( + *L, DT, LI, SE, MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, + DeleteCurrentLoop); + if (DeleteCurrentLoop) + LPM.markLoopAsDeleted(*L); + return Changed; } void getAnalysisUsage(AnalysisUsage &AU) const override { diff --git a/lib/Transforms/Scalar/LoopSink.cpp b/lib/Transforms/Scalar/LoopSink.cpp index 2f7ad2126ed3..975452e13f09 100644 --- a/lib/Transforms/Scalar/LoopSink.cpp +++ b/lib/Transforms/Scalar/LoopSink.cpp @@ -1,9 +1,8 @@ //===-- LoopSink.cpp - Loop Sink Pass -------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -291,10 +290,9 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, ColdLoopBBs.push_back(B); LoopBlockNumber[B] = ++i; } - std::stable_sort(ColdLoopBBs.begin(), ColdLoopBBs.end(), - [&](BasicBlock *A, BasicBlock *B) { - return BFI.getBlockFreq(A) < BFI.getBlockFreq(B); - }); + llvm::stable_sort(ColdLoopBBs, [&](BasicBlock *A, BasicBlock *B) { + return BFI.getBlockFreq(A) < BFI.getBlockFreq(B); + }); // Traverse preheader's instructions in reverse order becaue if A depends // on B (A appears after B), A needs to be sinked first before B can be diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 773ffb9df0a2..59a387a186b8 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -1,9 +1,8 @@ //===- LoopStrengthReduce.cpp - Strength Reduce IVs in Loops --------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -116,6 +115,7 @@ #include <cstdlib> #include <iterator> #include <limits> +#include <numeric> #include <map> #include <utility> @@ -155,11 +155,19 @@ static cl::opt<bool> FilterSameScaledReg( cl::desc("Narrow LSR search space by filtering non-optimal formulae" " with the same ScaledReg and Scale")); +static cl::opt<bool> EnableBackedgeIndexing( + "lsr-backedge-indexing", cl::Hidden, cl::init(true), + cl::desc("Enable the generation of cross iteration indexed memops")); + static cl::opt<unsigned> ComplexityLimit( "lsr-complexity-limit", cl::Hidden, cl::init(std::numeric_limits<uint16_t>::max()), cl::desc("LSR search space complexity limit")); +static cl::opt<unsigned> SetupCostDepthLimit( + "lsr-setupcost-depth-limit", cl::Hidden, cl::init(7), + cl::desc("The limit on recursion depth for LSRs setup cost")); + #ifndef NDEBUG // Stress test IV chain generation. static cl::opt<bool> StressIVChain( @@ -1007,10 +1015,15 @@ namespace { /// This class is used to measure and compare candidate formulae. 
class Cost { + const Loop *L = nullptr; + ScalarEvolution *SE = nullptr; + const TargetTransformInfo *TTI = nullptr; TargetTransformInfo::LSRCost C; public: - Cost() { + Cost() = delete; + Cost(const Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI) : + L(L), SE(&SE), TTI(&TTI) { C.Insns = 0; C.NumRegs = 0; C.AddRecCost = 0; @@ -1021,7 +1034,7 @@ public: C.ScaleCost = 0; } - bool isLess(Cost &Other, const TargetTransformInfo &TTI); + bool isLess(Cost &Other); void Lose(); @@ -1040,12 +1053,9 @@ public: return C.NumRegs == ~0u; } - void RateFormula(const TargetTransformInfo &TTI, - const Formula &F, + void RateFormula(const Formula &F, SmallPtrSetImpl<const SCEV *> &Regs, const DenseSet<const SCEV *> &VisitedRegs, - const Loop *L, - ScalarEvolution &SE, DominatorTree &DT, const LSRUse &LU, SmallPtrSetImpl<const SCEV *> *LoserRegs = nullptr); @@ -1053,17 +1063,11 @@ public: void dump() const; private: - void RateRegister(const SCEV *Reg, - SmallPtrSetImpl<const SCEV *> &Regs, - const Loop *L, - ScalarEvolution &SE, DominatorTree &DT, - const TargetTransformInfo &TTI); - void RatePrimaryRegister(const SCEV *Reg, + void RateRegister(const Formula &F, const SCEV *Reg, + SmallPtrSetImpl<const SCEV *> &Regs); + void RatePrimaryRegister(const Formula &F, const SCEV *Reg, SmallPtrSetImpl<const SCEV *> &Regs, - const Loop *L, - ScalarEvolution &SE, DominatorTree &DT, - SmallPtrSetImpl<const SCEV *> *LoserRegs, - const TargetTransformInfo &TTI); + SmallPtrSetImpl<const SCEV *> *LoserRegs); }; /// An operand value in an instruction which is to be replaced with some @@ -1208,19 +1212,36 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, bool HasBaseReg, int64_t Scale, Instruction *Fixup = nullptr); +static unsigned getSetupCost(const SCEV *Reg, unsigned Depth) { + if (isa<SCEVUnknown>(Reg) || isa<SCEVConstant>(Reg)) + return 1; + if (Depth == 0) + return 0; + if (const auto *S = dyn_cast<SCEVAddRecExpr>(Reg)) + return getSetupCost(S->getStart(), Depth - 1); + if (auto S = dyn_cast<SCEVCastExpr>(Reg)) + return getSetupCost(S->getOperand(), Depth - 1); + if (auto S = dyn_cast<SCEVNAryExpr>(Reg)) + return std::accumulate(S->op_begin(), S->op_end(), 0, + [&](unsigned i, const SCEV *Reg) { + return i + getSetupCost(Reg, Depth - 1); + }); + if (auto S = dyn_cast<SCEVUDivExpr>(Reg)) + return getSetupCost(S->getLHS(), Depth - 1) + + getSetupCost(S->getRHS(), Depth - 1); + return 0; +} + /// Tally up interesting quantities from the given register. -void Cost::RateRegister(const SCEV *Reg, - SmallPtrSetImpl<const SCEV *> &Regs, - const Loop *L, - ScalarEvolution &SE, DominatorTree &DT, - const TargetTransformInfo &TTI) { +void Cost::RateRegister(const Formula &F, const SCEV *Reg, + SmallPtrSetImpl<const SCEV *> &Regs) { if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Reg)) { // If this is an addrec for another loop, it should be an invariant // with respect to L since L is the innermost loop (at least // for now LSR only handles innermost loops). if (AR->getLoop() != L) { // If the AddRec exists, consider it's register free and leave it alone. - if (isExistingPhi(AR, SE)) + if (isExistingPhi(AR, *SE)) return; // It is bad to allow LSR for current loop to add induction variables @@ -1236,16 +1257,24 @@ void Cost::RateRegister(const SCEV *Reg, } unsigned LoopCost = 1; - if (TTI.shouldFavorPostInc()) { - const SCEV *LoopStep = AR->getStepRecurrence(SE); - if (isa<SCEVConstant>(LoopStep)) { - // Check if a post-indexed load/store can be used. 
- if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, AR->getType()) || - TTI.isIndexedStoreLegal(TTI.MIM_PostInc, AR->getType())) { + if (TTI->isIndexedLoadLegal(TTI->MIM_PostInc, AR->getType()) || + TTI->isIndexedStoreLegal(TTI->MIM_PostInc, AR->getType())) { + + // If the step size matches the base offset, we could use pre-indexed + // addressing. + if (TTI->shouldFavorBackedgeIndex(L)) { + if (auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE))) + if (Step->getAPInt() == F.BaseOffset) + LoopCost = 0; + } + + if (TTI->shouldFavorPostInc()) { + const SCEV *LoopStep = AR->getStepRecurrence(*SE); + if (isa<SCEVConstant>(LoopStep)) { const SCEV *LoopStart = AR->getStart(); if (!isa<SCEVConstant>(LoopStart) && - SE.isLoopInvariant(LoopStart, L)) - LoopCost = 0; + SE->isLoopInvariant(LoopStart, L)) + LoopCost = 0; } } } @@ -1255,7 +1284,7 @@ void Cost::RateRegister(const SCEV *Reg, // TODO: The non-affine case isn't precisely modeled here. if (!AR->isAffine() || !isa<SCEVConstant>(AR->getOperand(1))) { if (!Regs.count(AR->getOperand(1))) { - RateRegister(AR->getOperand(1), Regs, L, SE, DT, TTI); + RateRegister(F, AR->getOperand(1), Regs); if (isLoser()) return; } @@ -1265,43 +1294,34 @@ void Cost::RateRegister(const SCEV *Reg, // Rough heuristic; favor registers which don't require extra setup // instructions in the preheader. - if (!isa<SCEVUnknown>(Reg) && - !isa<SCEVConstant>(Reg) && - !(isa<SCEVAddRecExpr>(Reg) && - (isa<SCEVUnknown>(cast<SCEVAddRecExpr>(Reg)->getStart()) || - isa<SCEVConstant>(cast<SCEVAddRecExpr>(Reg)->getStart())))) - ++C.SetupCost; + C.SetupCost += getSetupCost(Reg, SetupCostDepthLimit); + // Ensure we don't, even with the recusion limit, produce invalid costs. + C.SetupCost = std::min<unsigned>(C.SetupCost, 1 << 16); C.NumIVMuls += isa<SCEVMulExpr>(Reg) && - SE.hasComputableLoopEvolution(Reg, L); + SE->hasComputableLoopEvolution(Reg, L); } /// Record this register in the set. If we haven't seen it before, rate /// it. Optional LoserRegs provides a way to declare any formula that refers to /// one of those regs an instant loser. 
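// Illustrative aside (hypothetical standalone sketch, not part of the patch):
// the getSetupCost() helper above charges one unit per leaf it can still
// reach within a recursion budget, so deeply nested expressions stop
// contributing once the depth limit is hit. The same shape over a toy
// expression tree:
#include <cstdio>
#include <memory>
#include <numeric>
#include <vector>

struct Expr {
  bool IsLeaf = false;
  std::vector<std::unique_ptr<Expr>> Ops;
};

static unsigned setupCost(const Expr &E, unsigned Depth) {
  if (E.IsLeaf)
    return 1; // leaves (constants/unknowns) cost one setup unit
  if (Depth == 0)
    return 0; // out of budget: treat the rest as free
  return std::accumulate(E.Ops.begin(), E.Ops.end(), 0u,
                         [&](unsigned Acc, const std::unique_ptr<Expr> &Op) {
                           return Acc + setupCost(*Op, Depth - 1);
                         });
}

int main() {
  // (leaf + (leaf + leaf)) -- two levels of nesting.
  Expr Root;
  auto Leaf = [] { auto E = std::make_unique<Expr>(); E->IsLeaf = true; return E; };
  auto Inner = std::make_unique<Expr>();
  Inner->Ops.push_back(Leaf());
  Inner->Ops.push_back(Leaf());
  Root.Ops.push_back(Leaf());
  Root.Ops.push_back(std::move(Inner));
  std::printf("depth 7 -> %u, depth 1 -> %u\n", setupCost(Root, 7),
              setupCost(Root, 1)); // prints "depth 7 -> 3, depth 1 -> 1"
  return 0;
}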
-void Cost::RatePrimaryRegister(const SCEV *Reg, +void Cost::RatePrimaryRegister(const Formula &F, const SCEV *Reg, SmallPtrSetImpl<const SCEV *> &Regs, - const Loop *L, - ScalarEvolution &SE, DominatorTree &DT, - SmallPtrSetImpl<const SCEV *> *LoserRegs, - const TargetTransformInfo &TTI) { + SmallPtrSetImpl<const SCEV *> *LoserRegs) { if (LoserRegs && LoserRegs->count(Reg)) { Lose(); return; } if (Regs.insert(Reg).second) { - RateRegister(Reg, Regs, L, SE, DT, TTI); + RateRegister(F, Reg, Regs); if (LoserRegs && isLoser()) LoserRegs->insert(Reg); } } -void Cost::RateFormula(const TargetTransformInfo &TTI, - const Formula &F, +void Cost::RateFormula(const Formula &F, SmallPtrSetImpl<const SCEV *> &Regs, const DenseSet<const SCEV *> &VisitedRegs, - const Loop *L, - ScalarEvolution &SE, DominatorTree &DT, const LSRUse &LU, SmallPtrSetImpl<const SCEV *> *LoserRegs) { assert(F.isCanonical(*L) && "Cost is accurate only for canonical formula"); @@ -1314,7 +1334,7 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, Lose(); return; } - RatePrimaryRegister(ScaledReg, Regs, L, SE, DT, LoserRegs, TTI); + RatePrimaryRegister(F, ScaledReg, Regs, LoserRegs); if (isLoser()) return; } @@ -1323,7 +1343,7 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, Lose(); return; } - RatePrimaryRegister(BaseReg, Regs, L, SE, DT, LoserRegs, TTI); + RatePrimaryRegister(F, BaseReg, Regs, LoserRegs); if (isLoser()) return; } @@ -1334,11 +1354,11 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, // Do not count the base and a possible second register if the target // allows to fold 2 registers. C.NumBaseAdds += - NumBaseParts - (1 + (F.Scale && isAMCompletelyFolded(TTI, LU, F))); + NumBaseParts - (1 + (F.Scale && isAMCompletelyFolded(*TTI, LU, F))); C.NumBaseAdds += (F.UnfoldedOffset != 0); // Accumulate non-free scaling amounts. - C.ScaleCost += getScalingFactorCost(TTI, LU, F, *L); + C.ScaleCost += getScalingFactorCost(*TTI, LU, F, *L); // Tally up the non-zero immediates. for (const LSRFixup &Fixup : LU.Fixups) { @@ -1353,7 +1373,7 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, // Check with target if this offset with this instruction is // specifically not supported. if (LU.Kind == LSRUse::Address && Offset != 0 && - !isAMCompletelyFolded(TTI, LSRUse::Address, LU.AccessTy, F.BaseGV, + !isAMCompletelyFolded(*TTI, LSRUse::Address, LU.AccessTy, F.BaseGV, Offset, F.HasBaseReg, F.Scale, Fixup.UserInst)) C.NumBaseAdds++; } @@ -1366,7 +1386,7 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, // Treat every new register that exceeds TTI.getNumberOfRegisters() - 1 as // additional instruction (at least fill). - unsigned TTIRegNum = TTI.getNumberOfRegisters(false) - 1; + unsigned TTIRegNum = TTI->getNumberOfRegisters(false) - 1; if (C.NumRegs > TTIRegNum) { // Cost already exceeded TTIRegNum, then only newly added register can add // new instructions. @@ -1386,7 +1406,8 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, // // For {-10, +, 1}: // i = i + 1; - if (LU.Kind == LSRUse::ICmpZero && !F.hasZeroEnd() && !TTI.canMacroFuseCmp()) + if (LU.Kind == LSRUse::ICmpZero && !F.hasZeroEnd() && + !TTI->canMacroFuseCmp()) C.Insns++; // Each new AddRec adds 1 instruction to calculation. C.Insns += (C.AddRecCost - PrevAddRecCost); @@ -1410,11 +1431,11 @@ void Cost::Lose() { } /// Choose the lower cost. 
-bool Cost::isLess(Cost &Other, const TargetTransformInfo &TTI) { +bool Cost::isLess(Cost &Other) { if (InsnsCost.getNumOccurrences() > 0 && InsnsCost && C.Insns != Other.C.Insns) return C.Insns < Other.C.Insns; - return TTI.isLSRCostLess(C, Other.C); + return TTI->isLSRCostLess(C, Other.C); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1888,8 +1909,11 @@ class LSRInstance { ScalarEvolution &SE; DominatorTree &DT; LoopInfo &LI; + AssumptionCache &AC; + TargetLibraryInfo &LibInfo; const TargetTransformInfo &TTI; Loop *const L; + bool FavorBackedgeIndex = false; bool Changed = false; /// This is the insert position that the current loop's induction variable @@ -1910,7 +1934,7 @@ class LSRInstance { SmallSetVector<Type *, 4> Types; /// The list of interesting uses. - SmallVector<LSRUse, 16> Uses; + mutable SmallVector<LSRUse, 16> Uses; /// Track which uses use which register candidates. RegUseTracker RegUses; @@ -2025,7 +2049,8 @@ class LSRInstance { public: LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT, - LoopInfo &LI, const TargetTransformInfo &TTI); + LoopInfo &LI, const TargetTransformInfo &TTI, AssumptionCache &AC, + TargetLibraryInfo &LibInfo); bool getChanged() const { return Changed; } @@ -2804,7 +2829,7 @@ bool IVChain::isProfitableIncrement(const SCEV *OperExpr, /// TODO: Consider IVInc free if it's already used in another chains. static bool isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users, - ScalarEvolution &SE, const TargetTransformInfo &TTI) { + ScalarEvolution &SE) { if (StressIVChain) return true; @@ -3064,7 +3089,7 @@ void LSRInstance::CollectChains() { for (unsigned UsersIdx = 0, NChains = IVChainVec.size(); UsersIdx < NChains; ++UsersIdx) { if (!isProfitableChain(IVChainVec[UsersIdx], - ChainUsersVec[UsersIdx].FarUsers, SE, TTI)) + ChainUsersVec[UsersIdx].FarUsers, SE)) continue; // Preserve the chain at UsesIdx. if (ChainIdx != UsersIdx) @@ -3078,7 +3103,7 @@ void LSRInstance::CollectChains() { void LSRInstance::FinalizeChain(IVChain &Chain) { assert(!Chain.Incs.empty() && "empty IV chains are not allowed"); LLVM_DEBUG(dbgs() << "Final Chain: " << *Chain.Incs[0].UserInst << "\n"); - + for (const IVInc &Inc : Chain) { LLVM_DEBUG(dbgs() << " Inc: " << *Inc.UserInst << "\n"); auto UseI = find(Inc.UserInst->operands(), Inc.IVOperand); @@ -3100,7 +3125,7 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst, MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand); int64_t IncOffset = IncConst->getValue()->getSExtValue(); if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr, - IncOffset, /*HaseBaseReg=*/false)) + IncOffset, /*HasBaseReg=*/false)) return false; return true; @@ -3210,6 +3235,9 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter, } void LSRInstance::CollectFixupsAndInitialFormulae() { + BranchInst *ExitBranch = nullptr; + bool SaveCmp = TTI.canSaveCmp(L, &ExitBranch, &SE, &LI, &DT, &AC, &LibInfo); + for (const IVStrideUse &U : IU) { Instruction *UserInst = U.getUser(); // Skip IV users that are part of profitable IV Chains. @@ -3239,6 +3267,10 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { // equality icmps, thanks to IndVarSimplify. if (ICmpInst *CI = dyn_cast<ICmpInst>(UserInst)) if (CI->isEquality()) { + // If CI can be saved in some target, like replaced inside hardware loop + // in PowerPC, no need to generate initial formulae for it. 
+ if (SaveCmp && CI == dyn_cast<ICmpInst>(ExitBranch->getCondition())) + continue; // Swap the operands if needed to put the OperandValToReplace on the // left, for consistency. Value *NV = CI->getOperand(1); @@ -3738,10 +3770,11 @@ void LSRInstance::GenerateSymbolicOffsets(LSRUse &LU, unsigned LUIdx, void LSRInstance::GenerateConstantOffsetsImpl( LSRUse &LU, unsigned LUIdx, const Formula &Base, const SmallVectorImpl<int64_t> &Worklist, size_t Idx, bool IsScaledReg) { - const SCEV *G = IsScaledReg ? Base.ScaledReg : Base.BaseRegs[Idx]; - for (int64_t Offset : Worklist) { + + auto GenerateOffset = [&](const SCEV *G, int64_t Offset) { Formula F = Base; F.BaseOffset = (uint64_t)Base.BaseOffset - Offset; + if (isLegalUse(TTI, LU.MinOffset - Offset, LU.MaxOffset - Offset, LU.Kind, LU.AccessTy, F)) { // Add the offset to the base register. @@ -3761,7 +3794,35 @@ void LSRInstance::GenerateConstantOffsetsImpl( (void)InsertFormula(LU, LUIdx, F); } + }; + + const SCEV *G = IsScaledReg ? Base.ScaledReg : Base.BaseRegs[Idx]; + + // With constant offsets and constant steps, we can generate pre-inc + // accesses by having the offset equal the step. So, for access #0 with a + // step of 8, we generate a G - 8 base which would require the first access + // to be ((G - 8) + 8),+,8. The pre-indexed access then updates the pointer + // for itself and hopefully becomes the base for other accesses. This means + // means that a single pre-indexed access can be generated to become the new + // base pointer for each iteration of the loop, resulting in no extra add/sub + // instructions for pointer updating. + if (FavorBackedgeIndex && LU.Kind == LSRUse::Address) { + if (auto *GAR = dyn_cast<SCEVAddRecExpr>(G)) { + if (auto *StepRec = + dyn_cast<SCEVConstant>(GAR->getStepRecurrence(SE))) { + const APInt &StepInt = StepRec->getAPInt(); + int64_t Step = StepInt.isNegative() ? + StepInt.getSExtValue() : StepInt.getZExtValue(); + + for (int64_t Offset : Worklist) { + Offset -= Step; + GenerateOffset(G, Offset); + } + } + } } + for (int64_t Offset : Worklist) + GenerateOffset(G, Offset); int64_t Imm = ExtractImmediate(G, SE); if (G->isZero() || Imm == 0) @@ -3968,9 +4029,27 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) { if (SrcTy != DstTy && TTI.isTruncateFree(SrcTy, DstTy)) { Formula F = Base; - if (F.ScaledReg) F.ScaledReg = SE.getAnyExtendExpr(F.ScaledReg, SrcTy); - for (const SCEV *&BaseReg : F.BaseRegs) - BaseReg = SE.getAnyExtendExpr(BaseReg, SrcTy); + // Sometimes SCEV is able to prove zero during ext transform. It may + // happen if SCEV did not do all possible transforms while creating the + // initial node (maybe due to depth limitations), but it can do them while + // taking ext. + if (F.ScaledReg) { + const SCEV *NewScaledReg = SE.getAnyExtendExpr(F.ScaledReg, SrcTy); + if (NewScaledReg->isZero()) + continue; + F.ScaledReg = NewScaledReg; + } + bool HasZeroBaseReg = false; + for (const SCEV *&BaseReg : F.BaseRegs) { + const SCEV *NewBaseReg = SE.getAnyExtendExpr(BaseReg, SrcTy); + if (NewBaseReg->isZero()) { + HasZeroBaseReg = true; + break; + } + BaseReg = NewBaseReg; + } + if (HasZeroBaseReg) + continue; // TODO: This assumes we've done basic processing on all uses and // have an idea what the register usage is. @@ -4067,11 +4146,17 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { // Conservatively examine offsets between this orig reg a few selected // other orig regs. 
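// Illustrative aside (standalone sketch, not part of the patch): the code
// just below computes (First + Last) / 2 without signed overflow by using the
// identity First + Last == 2 * (First & Last) + (First ^ Last), i.e. the sum
// is twice the common bits plus the differing bits. Halving each term before
// adding keeps every intermediate value in range:
#include <cassert>
#include <cstdint>
#include <limits>

static int64_t midpoint(int64_t First, int64_t Last) {
  int64_t Avg = (First & Last) + ((First ^ Last) >> 1);
  // Round towards zero instead of -inf when the exact average is negative and
  // First + Last is odd (same fixup as the code below).
  Avg += (First ^ Last) & ((uint64_t)Avg >> 63);
  return Avg;
}

int main() {
  assert(midpoint(6, 10) == 8);
  assert(midpoint(-7, 2) == -2); // (-7 + 2) / 2 rounded towards zero
  // Would overflow if written naively as (First + Last) / 2:
  int64_t Big = std::numeric_limits<int64_t>::max();
  assert(midpoint(Big, Big - 2) == Big - 1);
  return 0;
}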
+      int64_t First = Imms.begin()->first;
+      int64_t Last = std::prev(Imms.end())->first;
+      // Compute (First + Last) / 2 without overflow using the fact that
+      // First + Last = 2 * (First & Last) + (First ^ Last).
+      int64_t Avg = (First & Last) + ((First ^ Last) >> 1);
+      // If the result is negative and First is odd and Last even (or vice versa),
+      // we rounded towards -inf. Add 1 in that case, to round towards 0.
+      Avg = Avg + ((First ^ Last) & ((uint64_t)Avg >> 63));
       ImmMapTy::const_iterator OtherImms[] = {
-          Imms.begin(), std::prev(Imms.end()),
-          Imms.lower_bound((Imms.begin()->first + std::prev(Imms.end())->first) /
-                           2)
-      };
+          Imms.begin(), std::prev(Imms.end()),
+          Imms.lower_bound(Avg)};
       for (size_t i = 0, e = array_lengthof(OtherImms); i != e; ++i) {
         ImmMapTy::const_iterator M = OtherImms[i];
         if (M == J || M == JE) continue;
@@ -4249,9 +4334,9 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() {
       // avoids the need to recompute this information across formulae using the
       // same bad AddRec. Passing LoserRegs is also essential unless we remove
       // the corresponding bad register from the Regs set.
-      Cost CostF;
+      Cost CostF(L, SE, TTI);
       Regs.clear();
-      CostF.RateFormula(TTI, F, Regs, VisitedRegs, L, SE, DT, LU, &LoserRegs);
+      CostF.RateFormula(F, Regs, VisitedRegs, LU, &LoserRegs);
       if (CostF.isLoser()) {
         // During initial formula generation, undesirable formulae are generated
         // by uses within other loops that have some non-trivial address mode or
@@ -4282,10 +4367,10 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() {
       Formula &Best = LU.Formulae[P.first->second];
-      Cost CostBest;
+      Cost CostBest(L, SE, TTI);
       Regs.clear();
-      CostBest.RateFormula(TTI, Best, Regs, VisitedRegs, L, SE, DT, LU);
-      if (CostF.isLess(CostBest, TTI))
+      CostBest.RateFormula(Best, Regs, VisitedRegs, LU);
+      if (CostF.isLess(CostBest))
         std::swap(F, Best);
       LLVM_DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs());
                  dbgs() << "\n"
@@ -4357,7 +4442,9 @@ void LSRInstance::NarrowSearchSpaceByDetectingSupersets() {
              I = F.BaseRegs.begin(), E = F.BaseRegs.end(); I != E; ++I) {
           if (const SCEVConstant *C = dyn_cast<SCEVConstant>(*I)) {
             Formula NewF = F;
-            NewF.BaseOffset += C->getValue()->getSExtValue();
+            //FIXME: Formulas should store bitwidth to do wrapping properly.
+            // See PR41034.
+            NewF.BaseOffset += (uint64_t)C->getValue()->getSExtValue();
             NewF.BaseRegs.erase(NewF.BaseRegs.begin() +
                                 (I - F.BaseRegs.begin()));
             if (LU.HasFormulaWithSameRegs(NewF)) {
@@ -4400,7 +4487,7 @@ void LSRInstance::NarrowSearchSpaceByDetectingSupersets() {
 /// When there are many registers for expressions like A, A+1, A+2, etc.,
 /// allocate a single register for them.
 void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() {
-  if (EstimateSearchSpaceComplexity() < ComplexityLimit)
+  if (EstimateSearchSpaceComplexity() < ComplexityLimit)
     return;
   LLVM_DEBUG(
@@ -4533,12 +4620,13 @@ void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() {
         // If the new register numbers are the same, choose the Formula with
         // less Cost.
- Cost CostFA, CostFB; + Cost CostFA(L, SE, TTI); + Cost CostFB(L, SE, TTI); Regs.clear(); - CostFA.RateFormula(TTI, FA, Regs, VisitedRegs, L, SE, DT, LU); + CostFA.RateFormula(FA, Regs, VisitedRegs, LU); Regs.clear(); - CostFB.RateFormula(TTI, FB, Regs, VisitedRegs, L, SE, DT, LU); - return CostFA.isLess(CostFB, TTI); + CostFB.RateFormula(FB, Regs, VisitedRegs, LU); + return CostFA.isLess(CostFB); }; bool Any = false; @@ -4824,7 +4912,7 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, ReqRegs.insert(S); SmallPtrSet<const SCEV *, 16> NewRegs; - Cost NewCost; + Cost NewCost(L, SE, TTI); for (const Formula &F : LU.Formulae) { // Ignore formulae which may not be ideal in terms of register reuse of // ReqRegs. The formula should use all required registers before @@ -4848,8 +4936,8 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, // the current best, prune the search at that point. NewCost = CurCost; NewRegs = CurRegs; - NewCost.RateFormula(TTI, F, NewRegs, VisitedRegs, L, SE, DT, LU); - if (NewCost.isLess(SolutionCost, TTI)) { + NewCost.RateFormula(F, NewRegs, VisitedRegs, LU); + if (NewCost.isLess(SolutionCost)) { Workspace.push_back(&F); if (Workspace.size() != Uses.size()) { SolveRecurse(Solution, SolutionCost, Workspace, NewCost, @@ -4858,9 +4946,9 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, VisitedRegs.insert(F.ScaledReg ? F.ScaledReg : F.BaseRegs[0]); } else { LLVM_DEBUG(dbgs() << "New best at "; NewCost.print(dbgs()); - dbgs() << ".\n Regs:"; for (const SCEV *S - : NewRegs) dbgs() - << ' ' << *S; + dbgs() << ".\nRegs:\n"; + for (const SCEV *S : NewRegs) dbgs() + << "- " << *S << "\n"; dbgs() << '\n'); SolutionCost = NewCost; @@ -4875,9 +4963,9 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, /// vector. void LSRInstance::Solve(SmallVectorImpl<const Formula *> &Solution) const { SmallVector<const Formula *, 8> Workspace; - Cost SolutionCost; + Cost SolutionCost(L, SE, TTI); SolutionCost.Lose(); - Cost CurCost; + Cost CurCost(L, SE, TTI); SmallPtrSet<const SCEV *, 16> CurRegs; DenseSet<const SCEV *> VisitedRegs; Workspace.reserve(Uses.size()); @@ -5215,6 +5303,7 @@ void LSRInstance::RewriteForPHI( DenseMap<BasicBlock *, Value *> Inserted; for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) if (PN->getIncomingValue(i) == LF.OperandValToReplace) { + bool needUpdateFixups = false; BasicBlock *BB = PN->getIncomingBlock(i); // If this is a critical edge, split the edge so that we do not insert @@ -5233,7 +5322,7 @@ void LSRInstance::RewriteForPHI( NewBB = SplitCriticalEdge(BB, Parent, CriticalEdgeSplittingOptions(&DT, &LI) .setMergeIdenticalEdges() - .setDontDeleteUselessPHIs()); + .setKeepOneInputPHIs()); } else { SmallVector<BasicBlock*, 2> NewBBs; SplitLandingPadPredecessors(Parent, BB, "", "", NewBBs, &DT, &LI); @@ -5253,6 +5342,8 @@ void LSRInstance::RewriteForPHI( e = PN->getNumIncomingValues(); BB = NewBB; i = PN->getBasicBlockIndex(BB); + + needUpdateFixups = true; } } } @@ -5277,6 +5368,44 @@ void LSRInstance::RewriteForPHI( PN->setIncomingValue(i, FullV); Pair.first->second = FullV; } + + // If LSR splits critical edge and phi node has other pending + // fixup operands, we need to update those pending fixups. Otherwise + // formulae will not be implemented completely and some instructions + // will not be eliminated. 
+ if (needUpdateFixups) { + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) + for (LSRFixup &Fixup : Uses[LUIdx].Fixups) + // If fixup is supposed to rewrite some operand in the phi + // that was just updated, it may be already moved to + // another phi node. Such fixup requires update. + if (Fixup.UserInst == PN) { + // Check if the operand we try to replace still exists in the + // original phi. + bool foundInOriginalPHI = false; + for (const auto &val : PN->incoming_values()) + if (val == Fixup.OperandValToReplace) { + foundInOriginalPHI = true; + break; + } + + // If fixup operand found in original PHI - nothing to do. + if (foundInOriginalPHI) + continue; + + // Otherwise it might be moved to another PHI and requires update. + // If fixup operand not found in any of the incoming blocks that + // means we have already rewritten it - nothing to do. + for (const auto &Block : PN->blocks()) + for (BasicBlock::iterator I = Block->begin(); isa<PHINode>(I); + ++I) { + PHINode *NewPN = cast<PHINode>(I); + for (const auto &val : NewPN->incoming_values()) + if (val == Fixup.OperandValToReplace) + Fixup.UserInst = NewPN; + } + } + } } } @@ -5360,8 +5489,11 @@ void LSRInstance::ImplementSolution( LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, - const TargetTransformInfo &TTI) - : IU(IU), SE(SE), DT(DT), LI(LI), TTI(TTI), L(L) { + const TargetTransformInfo &TTI, AssumptionCache &AC, + TargetLibraryInfo &LibInfo) + : IU(IU), SE(SE), DT(DT), LI(LI), AC(AC), LibInfo(LibInfo), TTI(TTI), L(L), + FavorBackedgeIndex(EnableBackedgeIndexing && + TTI.shouldFavorBackedgeIndex(L)) { // If LoopSimplify form is not available, stay out of trouble. if (!L->isLoopSimplifyForm()) return; @@ -5556,6 +5688,8 @@ void LoopStrengthReduce::getAnalysisUsage(AnalysisUsage &AU) const { AU.addPreserved<DominatorTreeWrapperPass>(); AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addPreserved<ScalarEvolutionWrapperPass>(); + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); // Requiring LoopSimplify a second time here prevents IVUsers from running // twice, since LoopSimplify was invalidated by running ScalarEvolution. AU.addRequiredID(LoopSimplifyID); @@ -5566,11 +5700,14 @@ void LoopStrengthReduce::getAnalysisUsage(AnalysisUsage &AU) const { static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, - const TargetTransformInfo &TTI) { + const TargetTransformInfo &TTI, + AssumptionCache &AC, + TargetLibraryInfo &LibInfo) { + bool Changed = false; // Run the main LSR transformation. - Changed |= LSRInstance(L, IU, SE, DT, LI, TTI).getChanged(); + Changed |= LSRInstance(L, IU, SE, DT, LI, TTI, AC, LibInfo).getChanged(); // Remove any extra phis created by processing inner loops. 
Changed |= DeleteDeadPHIs(L->getHeader()); @@ -5601,14 +5738,17 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI( *L->getHeader()->getParent()); - return ReduceLoopStrength(L, IU, SE, DT, LI, TTI); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *L->getHeader()->getParent()); + auto &LibInfo = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + return ReduceLoopStrength(L, IU, SE, DT, LI, TTI, AC, LibInfo); } PreservedAnalyses LoopStrengthReducePass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { if (!ReduceLoopStrength(&L, AM.getResult<IVUsersAnalysis>(L, AR), AR.SE, - AR.DT, AR.LI, AR.TTI)) + AR.DT, AR.LI, AR.TTI, AR.AC, AR.TLI)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); diff --git a/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index da46210b6fdd..86891eb451bb 100644 --- a/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -1,9 +1,8 @@ //===- LoopUnrollAndJam.cpp - Loop unroll and jam pass --------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -295,7 +294,8 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, return LoopUnrollResult::Unmodified; TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( - L, SE, TTI, OptLevel, None, None, None, None, None, None); + L, SE, TTI, nullptr, nullptr, OptLevel, + None, None, None, None, None, None); if (AllowUnrollAndJam.getNumOccurrences() > 0) UP.UnrollAndJam = AllowUnrollAndJam; if (UnrollAndJamThreshold.getNumOccurrences() > 0) diff --git a/lib/Transforms/Scalar/LoopUnrollPass.cpp b/lib/Transforms/Scalar/LoopUnrollPass.cpp index 38b80f48ed0e..2fa7436213dd 100644 --- a/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -1,9 +1,8 @@ //===- LoopUnroll.cpp - Loop unroller pass --------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -24,7 +23,9 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/LazyBlockFrequencyInfo.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -56,6 +57,7 @@ #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/SizeOpts.h" #include "llvm/Transforms/Utils/UnrollLoop.h" #include <algorithm> #include <cassert> @@ -69,6 +71,12 @@ using namespace llvm; #define DEBUG_TYPE "loop-unroll" +cl::opt<bool> llvm::ForgetSCEVInLoopUnroll( + "forget-scev-loop-unroll", cl::init(false), cl::Hidden, + cl::desc("Forget everything in SCEV when doing LoopUnroll, instead of just" + " the current top-most loop. This is somtimes preferred to reduce" + " compile time.")); + static cl::opt<unsigned> UnrollThreshold("unroll-threshold", cl::Hidden, cl::desc("The cost threshold for loop unrolling")); @@ -166,7 +174,8 @@ static const unsigned NoThreshold = std::numeric_limits<unsigned>::max(); /// Gather the various unrolling parameters based on the defaults, compiler /// flags, TTI overrides and user specified parameters. TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( - Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI, int OptLevel, + Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI, + BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, int OptLevel, Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, Optional<bool> UserAllowPartial, Optional<bool> UserRuntime, Optional<bool> UserUpperBound, Optional<bool> UserAllowPeeling) { @@ -199,9 +208,12 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( TTI.getUnrollingPreferences(L, SE, UP); // Apply size attributes - if (L->getHeader()->getParent()->optForSize()) { + bool OptForSize = L->getHeader()->getParent()->hasOptSize() || + llvm::shouldOptimizeForSize(L->getHeader(), PSI, BFI); + if (OptForSize) { UP.Threshold = UP.OptSizeThreshold; UP.PartialThreshold = UP.PartialOptSizeThreshold; + UP.MaxPercentThresholdBoost = 100; } // Apply any user values specified by cl::opt @@ -964,8 +976,10 @@ bool llvm::computeUnrollCount( static LoopUnrollResult tryToUnrollLoop( Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, const TargetTransformInfo &TTI, AssumptionCache &AC, - OptimizationRemarkEmitter &ORE, bool PreserveLCSSA, int OptLevel, - bool OnlyWhenForced, Optional<unsigned> ProvidedCount, + OptimizationRemarkEmitter &ORE, + BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, + bool PreserveLCSSA, int OptLevel, + bool OnlyWhenForced, bool ForgetAllSCEV, Optional<unsigned> ProvidedCount, Optional<unsigned> ProvidedThreshold, Optional<bool> ProvidedAllowPartial, Optional<bool> ProvidedRuntime, Optional<bool> ProvidedUpperBound, Optional<bool> ProvidedAllowPeeling) { @@ -986,15 +1000,19 @@ static LoopUnrollResult tryToUnrollLoop( if (OnlyWhenForced && !(TM & TM_Enable)) return LoopUnrollResult::Unmodified; + bool OptForSize = L->getHeader()->getParent()->hasOptSize(); unsigned NumInlineCandidates; bool NotDuplicatable; bool Convergent; TargetTransformInfo::UnrollingPreferences UP = 
gatherUnrollingPreferences( - L, SE, TTI, OptLevel, ProvidedThreshold, ProvidedCount, + L, SE, TTI, BFI, PSI, OptLevel, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, ProvidedAllowPeeling); - // Exit early if unrolling is disabled. - if (UP.Threshold == 0 && (!UP.Partial || UP.PartialThreshold == 0)) + + // Exit early if unrolling is disabled. For OptForSize, we pick the loop size + // as threshold later on. + if (UP.Threshold == 0 && (!UP.Partial || UP.PartialThreshold == 0) && + !OptForSize) return LoopUnrollResult::Unmodified; SmallPtrSet<const Value *, 32> EphValues; @@ -1009,6 +1027,12 @@ static LoopUnrollResult tryToUnrollLoop( << " instructions.\n"); return LoopUnrollResult::Unmodified; } + + // When optimizing for size, use LoopSize as threshold, to (fully) unroll + // loops, if it does not increase code size. + if (OptForSize) + UP.Threshold = std::max(UP.Threshold, LoopSize); + if (NumInlineCandidates != 0) { LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); return LoopUnrollResult::Unmodified; @@ -1081,8 +1105,10 @@ static LoopUnrollResult tryToUnrollLoop( // Unroll the loop. Loop *RemainderLoop = nullptr; LoopUnrollResult UnrollResult = UnrollLoop( - L, UP.Count, TripCount, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount, - UseUpperBound, MaxOrZero, TripMultiple, UP.PeelCount, UP.UnrollRemainder, + L, + {UP.Count, TripCount, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount, + UseUpperBound, MaxOrZero, TripMultiple, UP.PeelCount, UP.UnrollRemainder, + ForgetAllSCEV}, LI, &SE, &DT, &AC, &ORE, PreserveLCSSA, &RemainderLoop); if (UnrollResult == LoopUnrollResult::Unmodified) return LoopUnrollResult::Unmodified; @@ -1132,6 +1158,11 @@ public: /// metadata are considered. All other loops are skipped. bool OnlyWhenForced; + /// If false, when SCEV is invalidated, only forget everything in the + /// top-most loop (call forgetTopMostLoop), of the loop being processed. + /// Otherwise, forgetAllLoops and rebuild when needed next. 
+ bool ForgetAllSCEV; + Optional<unsigned> ProvidedCount; Optional<unsigned> ProvidedThreshold; Optional<bool> ProvidedAllowPartial; @@ -1140,15 +1171,16 @@ public: Optional<bool> ProvidedAllowPeeling; LoopUnroll(int OptLevel = 2, bool OnlyWhenForced = false, - Optional<unsigned> Threshold = None, + bool ForgetAllSCEV = false, Optional<unsigned> Threshold = None, Optional<unsigned> Count = None, Optional<bool> AllowPartial = None, Optional<bool> Runtime = None, Optional<bool> UpperBound = None, Optional<bool> AllowPeeling = None) : LoopPass(ID), OptLevel(OptLevel), OnlyWhenForced(OnlyWhenForced), - ProvidedCount(std::move(Count)), ProvidedThreshold(Threshold), - ProvidedAllowPartial(AllowPartial), ProvidedRuntime(Runtime), - ProvidedUpperBound(UpperBound), ProvidedAllowPeeling(AllowPeeling) { + ForgetAllSCEV(ForgetAllSCEV), ProvidedCount(std::move(Count)), + ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial), + ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound), + ProvidedAllowPeeling(AllowPeeling) { initializeLoopUnrollPass(*PassRegistry::getPassRegistry()); } @@ -1171,9 +1203,10 @@ public: bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); LoopUnrollResult Result = tryToUnrollLoop( - L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, OptLevel, OnlyWhenForced, - ProvidedCount, ProvidedThreshold, ProvidedAllowPartial, ProvidedRuntime, - ProvidedUpperBound, ProvidedAllowPeeling); + L, DT, LI, SE, TTI, AC, ORE, nullptr, nullptr, + PreserveLCSSA, OptLevel, OnlyWhenForced, + ForgetAllSCEV, ProvidedCount, ProvidedThreshold, ProvidedAllowPartial, + ProvidedRuntime, ProvidedUpperBound, ProvidedAllowPeeling); if (Result == LoopUnrollResult::FullyUnrolled) LPM.markLoopAsDeleted(*L); @@ -1203,14 +1236,14 @@ INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(LoopUnroll, "loop-unroll", "Unroll loops", false, false) Pass *llvm::createLoopUnrollPass(int OptLevel, bool OnlyWhenForced, - int Threshold, int Count, int AllowPartial, - int Runtime, int UpperBound, + bool ForgetAllSCEV, int Threshold, int Count, + int AllowPartial, int Runtime, int UpperBound, int AllowPeeling) { // TODO: It would make more sense for this function to take the optionals // directly, but that's dangerous since it would silently break out of tree // callers. return new LoopUnroll( - OptLevel, OnlyWhenForced, + OptLevel, OnlyWhenForced, ForgetAllSCEV, Threshold == -1 ? None : Optional<unsigned>(Threshold), Count == -1 ? None : Optional<unsigned>(Count), AllowPartial == -1 ? None : Optional<bool>(AllowPartial), @@ -1219,8 +1252,10 @@ Pass *llvm::createLoopUnrollPass(int OptLevel, bool OnlyWhenForced, AllowPeeling == -1 ? 
None : Optional<bool>(AllowPeeling)); } -Pass *llvm::createSimpleLoopUnrollPass(int OptLevel, bool OnlyWhenForced) { - return createLoopUnrollPass(OptLevel, OnlyWhenForced, -1, -1, 0, 0, 0, 0); +Pass *llvm::createSimpleLoopUnrollPass(int OptLevel, bool OnlyWhenForced, + bool ForgetAllSCEV) { + return createLoopUnrollPass(OptLevel, OnlyWhenForced, ForgetAllSCEV, -1, -1, + 0, 0, 0, 0); } PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, @@ -1250,8 +1285,9 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE, + /*BFI*/ nullptr, /*PSI*/ nullptr, /*PreserveLCSSA*/ true, OptLevel, OnlyWhenForced, - /*Count*/ None, + ForgetSCEV, /*Count*/ None, /*Threshold*/ None, /*AllowPartial*/ false, /*Runtime*/ false, /*UpperBound*/ false, /*AllowPeeling*/ false) != LoopUnrollResult::Unmodified; @@ -1352,6 +1388,8 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + auto *BFI = (PSI && PSI->hasProfileSummary()) ? + &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; bool Changed = false; @@ -1361,7 +1399,8 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, // will simplify all loops, regardless of whether anything end up being // unrolled. for (auto &L : LI) { - Changed |= simplifyLoop(L, &DT, &LI, &SE, &AC, false /* PreserveLCSSA */); + Changed |= + simplifyLoop(L, &DT, &LI, &SE, &AC, nullptr, false /* PreserveLCSSA */); Changed |= formLCSSARecursively(*L, DT, &LI, &SE); } @@ -1387,9 +1426,9 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, // The API here is quite complex to call and we allow to select some // flavors of unrolling during construction time (by setting UnrollOpts). LoopUnrollResult Result = tryToUnrollLoop( - &L, DT, &LI, SE, TTI, AC, ORE, + &L, DT, &LI, SE, TTI, AC, ORE, BFI, PSI, /*PreserveLCSSA*/ true, UnrollOpts.OptLevel, UnrollOpts.OnlyWhenForced, - /*Count*/ None, + UnrollOpts.ForgetSCEV, /*Count*/ None, /*Threshold*/ None, UnrollOpts.AllowPartial, UnrollOpts.AllowRuntime, UnrollOpts.AllowUpperBound, LocalAllowPeeling); Changed |= Result != LoopUnrollResult::Unmodified; diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp index 4a089dfa7dbf..b5b8e720069c 100644 --- a/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -1,9 +1,8 @@ //===- LoopUnswitch.cpp - Hoist loop-invariant conditionals in loop -------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -658,7 +657,7 @@ bool LoopUnswitch::processCurrentLoop() { } // Do not do non-trivial unswitch while optimizing for size. - // FIXME: Use Function::optForSize(). + // FIXME: Use Function::hasOptSize(). 
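// The LoopUnrollPass::run hunk above requests BlockFrequencyAnalysis only when
// a profile summary is present, so builds without PGO data never pay for BFI.
// A minimal standalone sketch of that acquisition pattern (the helper name is
// illustrative; the analyses and the proxy are the ones the hunk itself uses):

#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/IR/PassManager.h"

using namespace llvm;

static BlockFrequencyInfo *getBFIIfProfiled(Function &F,
                                            FunctionAnalysisManager &AM) {
  // ProfileSummaryInfo is a module analysis; reach it through the module
  // analysis manager proxy, exactly as the pass does.
  const ModuleAnalysisManager &MAM =
      AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager();
  ProfileSummaryInfo *PSI =
      MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
  // Only request (and therefore compute) BFI when profile data exists.
  return (PSI && PSI->hasProfileSummary())
             ? &AM.getResult<BlockFrequencyAnalysis>(F)
             : nullptr;
}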
if (OptimizeForSize || loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize)) return false; @@ -1405,8 +1404,8 @@ static void RemoveFromWorklist(Instruction *I, /// When we find that I really equals V, remove I from the /// program, replacing all uses with V and update the worklist. static void ReplaceUsesOfWith(Instruction *I, Value *V, - std::vector<Instruction*> &Worklist, - Loop *L, LPPassManager *LPM) { + std::vector<Instruction *> &Worklist, Loop *L, + LPPassManager *LPM, MemorySSAUpdater *MSSAU) { LLVM_DEBUG(dbgs() << "Replace with '" << *V << "': " << *I << "\n"); // Add uses to the worklist, which may be dead now. @@ -1420,8 +1419,11 @@ static void ReplaceUsesOfWith(Instruction *I, Value *V, LPM->deleteSimpleAnalysisValue(I, L); RemoveFromWorklist(I, Worklist); I->replaceAllUsesWith(V); - if (!I->mayHaveSideEffects()) + if (!I->mayHaveSideEffects()) { + if (MSSAU) + MSSAU->removeMemoryAccess(I); I->eraseFromParent(); + } ++NumSimplify; } @@ -1548,8 +1550,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, ConstantInt::getTrue(Context), NewSISucc); // Release the PHI operands for this edge. for (PHINode &PN : NewSISucc->phis()) - PN.setIncomingValue(PN.getBasicBlockIndex(Switch), - UndefValue::get(PN.getType())); + PN.setIncomingValueForBlock(Switch, UndefValue::get(PN.getType())); // Tell the domtree about the new block. We don't fully update the // domtree here -- instead we force it to do a full recomputation // after the pass is complete -- but we do need to inform it of @@ -1596,7 +1597,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { // 'false'. TODO: update the domtree properly so we can pass it here. if (Value *V = SimplifyInstruction(I, DL)) if (LI->replacementPreservesLCSSAForm(I, V)) { - ReplaceUsesOfWith(I, V, Worklist, L, LPM); + ReplaceUsesOfWith(I, V, Worklist, L, LPM, MSSAU.get()); continue; } @@ -1616,7 +1617,8 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { // Resolve any single entry PHI nodes in Succ. while (PHINode *PN = dyn_cast<PHINode>(Succ->begin())) - ReplaceUsesOfWith(PN, PN->getIncomingValue(0), Worklist, L, LPM); + ReplaceUsesOfWith(PN, PN->getIncomingValue(0), Worklist, L, LPM, + MSSAU.get()); // If Succ has any successors with PHI nodes, update them to have // entries coming from Pred instead of Succ. diff --git a/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 83861b98fbd8..896dd8bcb922 100644 --- a/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -1,9 +1,8 @@ //===- LoopVersioningLICM.cpp - LICM Loop Versioning ----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -357,14 +356,22 @@ bool LoopVersioningLICM::legalLoopMemoryAccesses() { /// 1) Check all load store in loop body are non atomic & non volatile. /// 2) Check function call safety, by ensuring its not accessing memory. /// 3) Loop body shouldn't have any may throw instruction. +/// 4) Loop body shouldn't have any convergent or noduplicate instructions. 
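// A condensed sketch of the four safety conditions listed above, using only
// the predicates that instructionSafeForVersioning (below) actually queries;
// the helper name is illustrative and the debug output of the real code is
// omitted:

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

static bool sketchIsVersioningSafe(Instruction &I, AliasAnalysis &AA) {
  if (auto *Call = dyn_cast<CallBase>(&I)) {
    // (4) Convergent / noduplicate calls may not be cloned, and versioning
    //     clones the whole loop body.
    if (Call->isConvergent() || Call->cannotDuplicate())
      return false;
    // (2) Calls must not access memory at all.
    if (!AA.doesNotAccessMemory(Call))
      return false;
  }
  // (3) Nothing that may throw.
  if (I.mayThrow())
    return false;
  // (1) Loads and stores must be non-atomic and non-volatile ("simple").
  if (auto *LI = dyn_cast<LoadInst>(&I))
    return LI->isSimple();
  if (auto *SI = dyn_cast<StoreInst>(&I))
    return SI->isSimple();
  return true;
}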
bool LoopVersioningLICM::instructionSafeForVersioning(Instruction *I) { assert(I != nullptr && "Null instruction found!"); // Check function call safety - if (auto *Call = dyn_cast<CallBase>(I)) + if (auto *Call = dyn_cast<CallBase>(I)) { + if (Call->isConvergent() || Call->cannotDuplicate()) { + LLVM_DEBUG(dbgs() << " Convergent call site found.\n"); + return false; + } + if (!AA->doesNotAccessMemory(Call)) { LLVM_DEBUG(dbgs() << " Unsafe call site found.\n"); return false; } + } + // Avoid loops with possiblity of throw if (I->mayThrow()) { LLVM_DEBUG(dbgs() << " May throw instruction found in loop body\n"); diff --git a/lib/Transforms/Scalar/LowerAtomic.cpp b/lib/Transforms/Scalar/LowerAtomic.cpp index c165c5ece95c..e076424d9042 100644 --- a/lib/Transforms/Scalar/LowerAtomic.cpp +++ b/lib/Transforms/Scalar/LowerAtomic.cpp @@ -1,9 +1,8 @@ //===- LowerAtomic.cpp - Lower atomic intrinsics --------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -27,7 +26,7 @@ static bool LowerAtomicCmpXchgInst(AtomicCmpXchgInst *CXI) { Value *Cmp = CXI->getCompareOperand(); Value *Val = CXI->getNewValOperand(); - LoadInst *Orig = Builder.CreateLoad(Ptr); + LoadInst *Orig = Builder.CreateLoad(Val->getType(), Ptr); Value *Equal = Builder.CreateICmpEQ(Orig, Cmp); Value *Res = Builder.CreateSelect(Equal, Val, Orig); Builder.CreateStore(Res, Ptr); @@ -45,7 +44,7 @@ static bool LowerAtomicRMWInst(AtomicRMWInst *RMWI) { Value *Ptr = RMWI->getPointerOperand(); Value *Val = RMWI->getValOperand(); - LoadInst *Orig = Builder.CreateLoad(Ptr); + LoadInst *Orig = Builder.CreateLoad(Val->getType(), Ptr); Value *Res = nullptr; switch (RMWI->getOperation()) { @@ -87,6 +86,12 @@ static bool LowerAtomicRMWInst(AtomicRMWInst *RMWI) { Res = Builder.CreateSelect(Builder.CreateICmpULT(Orig, Val), Orig, Val); break; + case AtomicRMWInst::FAdd: + Res = Builder.CreateFAdd(Orig, Val); + break; + case AtomicRMWInst::FSub: + Res = Builder.CreateFSub(Orig, Val); + break; } Builder.CreateStore(Res, Ptr); RMWI->replaceAllUsesWith(Orig); diff --git a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 68bfa0030395..0d67c0d740ec 100644 --- a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -1,9 +1,8 @@ //===- LowerExpectIntrinsic.cpp - Lower expect intrinsic ------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp b/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp index 4867b33d671f..9489e01774d6 100644 --- a/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp +++ b/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp @@ -1,9 +1,8 @@ //===- LowerGuardIntrinsic.cpp - Lower the guard intrinsic ---------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/LowerWidenableCondition.cpp b/lib/Transforms/Scalar/LowerWidenableCondition.cpp new file mode 100644 index 000000000000..5342f2ddcb6b --- /dev/null +++ b/lib/Transforms/Scalar/LowerWidenableCondition.cpp @@ -0,0 +1,85 @@ +//===- LowerWidenableCondition.cpp - Lower the guard intrinsic ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass lowers the llvm.widenable.condition intrinsic to default value +// which is i1 true. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LowerWidenableCondition.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/GuardUtils.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/GuardUtils.h" + +using namespace llvm; + +namespace { +struct LowerWidenableConditionLegacyPass : public FunctionPass { + static char ID; + LowerWidenableConditionLegacyPass() : FunctionPass(ID) { + initializeLowerWidenableConditionLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; +}; +} + +static bool lowerWidenableCondition(Function &F) { + // Check if we can cheaply rule out the possibility of not having any work to + // do. 
+ auto *WCDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_widenable_condition)); + if (!WCDecl || WCDecl->use_empty()) + return false; + + using namespace llvm::PatternMatch; + SmallVector<CallInst *, 8> ToLower; + for (auto &I : instructions(F)) + if (match(&I, m_Intrinsic<Intrinsic::experimental_widenable_condition>())) + ToLower.push_back(cast<CallInst>(&I)); + + if (ToLower.empty()) + return false; + + for (auto *CI : ToLower) { + CI->replaceAllUsesWith(ConstantInt::getTrue(CI->getContext())); + CI->eraseFromParent(); + } + return true; +} + +bool LowerWidenableConditionLegacyPass::runOnFunction(Function &F) { + return lowerWidenableCondition(F); +} + +char LowerWidenableConditionLegacyPass::ID = 0; +INITIALIZE_PASS(LowerWidenableConditionLegacyPass, "lower-widenable-condition", + "Lower the widenable condition to default true value", false, + false) + +Pass *llvm::createLowerWidenableConditionPass() { + return new LowerWidenableConditionLegacyPass(); +} + +PreservedAnalyses LowerWidenableConditionPass::run(Function &F, + FunctionAnalysisManager &AM) { + if (lowerWidenableCondition(F)) + return PreservedAnalyses::none(); + + return PreservedAnalyses::all(); +} diff --git a/lib/Transforms/Scalar/MakeGuardsExplicit.cpp b/lib/Transforms/Scalar/MakeGuardsExplicit.cpp index 1ba3994eba0e..789232e0f5ce 100644 --- a/lib/Transforms/Scalar/MakeGuardsExplicit.cpp +++ b/lib/Transforms/Scalar/MakeGuardsExplicit.cpp @@ -1,9 +1,8 @@ //===- MakeGuardsExplicit.cpp - Turn guard intrinsics into guard branches -===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp index ced923d6973d..5a055139be4f 100644 --- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -1,9 +1,8 @@ //===- MemCpyOptimizer.cpp - Optimize use of memcpy and friends -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -279,8 +278,8 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr, unsigned Alignment, Instruction *Inst) { int64_t End = Start+Size; - range_iterator I = std::lower_bound(Ranges.begin(), Ranges.end(), Start, - [](const MemsetRange &LHS, int64_t RHS) { return LHS.End < RHS; }); + range_iterator I = partition_point( + Ranges, [=](const MemsetRange &O) { return O.End < Start; }); // We now know that I == E, in which case we didn't find anything to merge // with, or that Start <= I->End. If End < I->Start or I == E, then we need @@ -413,7 +412,7 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, if (!NextStore->isSimple()) break; // Check to see if this stored value is of the same byte-splattable value. 
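// tryMergingIntoMemset only folds stores whose value is "byte-splattable",
// that is, the same byte repeated across the whole width (which is what
// isBytewiseValue, now also taking the DataLayout, computes). A small
// standalone illustration of that property for one of the constants named in
// a comment below, 0xA0A0A0A0:

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  const uint32_t Splat = 0xA0A0A0A0u; // every byte is 0xA0
  unsigned char Bytes[sizeof(Splat)];
  std::memcpy(Bytes, &Splat, sizeof(Splat));
  for (unsigned char B : Bytes)
    assert(B == 0xA0); // so a run of such stores can become memset(p, 0xA0, N)
  return 0;
}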
- Value *StoredByte = isBytewiseValue(NextStore->getOperand(0)); + Value *StoredByte = isBytewiseValue(NextStore->getOperand(0), DL); if (isa<UndefValue>(ByteVal) && StoredByte) ByteVal = StoredByte; if (ByteVal != StoredByte) @@ -750,7 +749,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { // byte at a time like "0" or "-1" or any width, as well as things like // 0xA0A0A0A0 and 0.0. auto *V = SI->getOperand(0); - if (Value *ByteVal = isBytewiseValue(V)) { + if (Value *ByteVal = isBytewiseValue(V, DL)) { if (Instruction *I = tryMergingIntoMemset(SI, SI->getPointerOperand(), ByteVal)) { BBI = I->getIterator(); // Don't invalidate iterator. @@ -1135,8 +1134,10 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, Value *SizeDiff = Builder.CreateSub(DestSize, SrcSize); Value *MemsetLen = Builder.CreateSelect( Ule, ConstantInt::getNullValue(DestSize->getType()), SizeDiff); - Builder.CreateMemSet(Builder.CreateGEP(Dest, SrcSize), MemSet->getOperand(1), - MemsetLen, Align); + Builder.CreateMemSet( + Builder.CreateGEP(Dest->getType()->getPointerElementType(), Dest, + SrcSize), + MemSet->getOperand(1), MemsetLen, Align); MD->removeInstruction(MemSet); MemSet->eraseFromParent(); @@ -1228,7 +1229,8 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M) { // If copying from a constant, try to turn the memcpy into a memset. if (GlobalVariable *GV = dyn_cast<GlobalVariable>(M->getSource())) if (GV->isConstant() && GV->hasDefinitiveInitializer()) - if (Value *ByteVal = isBytewiseValue(GV->getInitializer())) { + if (Value *ByteVal = isBytewiseValue(GV->getInitializer(), + M->getModule()->getDataLayout())) { IRBuilder<> Builder(M); Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(), M->getDestAlignment(), false); diff --git a/lib/Transforms/Scalar/MergeICmps.cpp b/lib/Transforms/Scalar/MergeICmps.cpp index 69fd8b163a07..3d047a193267 100644 --- a/lib/Transforms/Scalar/MergeICmps.cpp +++ b/lib/Transforms/Scalar/MergeICmps.cpp @@ -1,9 +1,8 @@ //===- MergeICmps.cpp - Optimize chains of integer comparisons ------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -11,29 +10,54 @@ // later typically inlined as a chain of efficient hardware comparisons). This // typically benefits c++ member or nonmember operator==(). // -// The basic idea is to replace a larger chain of integer comparisons loaded -// from contiguous memory locations into a smaller chain of such integer +// The basic idea is to replace a longer chain of integer comparisons loaded +// from contiguous memory locations into a shorter chain of larger integer // comparisons. Benefits are double: // - There are less jumps, and therefore less opportunities for mispredictions // and I-cache misses. // - Code size is smaller, both because jumps are removed and because the // encoding of a 2*n byte compare is smaller than that of two n-byte // compares. 
- +// +// Example: +// +// struct S { +// int a; +// char b; +// char c; +// uint16_t d; +// bool operator==(const S& o) const { +// return a == o.a && b == o.b && c == o.c && d == o.d; +// } +// }; +// +// Is optimized as : +// +// bool S::operator==(const S& o) const { +// return memcmp(this, &o, 8) == 0; +// } +// +// Which will later be expanded (ExpandMemCmp) as a single 8-bytes icmp. +// //===----------------------------------------------------------------------===// -#include <algorithm> -#include <numeric> -#include <utility> -#include <vector> +#include "llvm/Transforms/Scalar/MergeICmps.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" +#include <algorithm> +#include <numeric> +#include <utility> +#include <vector> using namespace llvm; @@ -50,76 +74,109 @@ static bool isSimpleLoadOrStore(const Instruction *I) { return false; } -// A BCE atom. +// A BCE atom "Binary Compare Expression Atom" represents an integer load +// that is a constant offset from a base value, e.g. `a` or `o.c` in the example +// at the top. struct BCEAtom { - BCEAtom() : GEP(nullptr), LoadI(nullptr), Offset() {} - - const Value *Base() const { return GEP ? GEP->getPointerOperand() : nullptr; } - + BCEAtom() = default; + BCEAtom(GetElementPtrInst *GEP, LoadInst *LoadI, int BaseId, APInt Offset) + : GEP(GEP), LoadI(LoadI), BaseId(BaseId), Offset(Offset) {} + + BCEAtom(const BCEAtom &) = delete; + BCEAtom &operator=(const BCEAtom &) = delete; + + BCEAtom(BCEAtom &&that) = default; + BCEAtom &operator=(BCEAtom &&that) { + if (this == &that) + return *this; + GEP = that.GEP; + LoadI = that.LoadI; + BaseId = that.BaseId; + Offset = std::move(that.Offset); + return *this; + } + + // We want to order BCEAtoms by (Base, Offset). However we cannot use + // the pointer values for Base because these are non-deterministic. + // To make sure that the sort order is stable, we first assign to each atom + // base value an index based on its order of appearance in the chain of + // comparisons. We call this index `BaseOrdering`. For example, for: + // b[3] == c[2] && a[1] == d[1] && b[4] == c[3] + // | block 1 | | block 2 | | block 3 | + // b gets assigned index 0 and a index 1, because b appears as LHS in block 1, + // which is before block 2. + // We then sort by (BaseOrdering[LHS.Base()], LHS.Offset), which is stable. bool operator<(const BCEAtom &O) const { - assert(Base() && "invalid atom"); - assert(O.Base() && "invalid atom"); - // Just ordering by (Base(), Offset) is sufficient. However because this - // means that the ordering will depend on the addresses of the base - // values, which are not reproducible from run to run. To guarantee - // stability, we use the names of the values if they exist; we sort by: - // (Base.getName(), Base(), Offset). - const int NameCmp = Base()->getName().compare(O.Base()->getName()); - if (NameCmp == 0) { - if (Base() == O.Base()) { - return Offset.slt(O.Offset); - } - return Base() < O.Base(); - } - return NameCmp < 0; + return BaseId != O.BaseId ? 
BaseId < O.BaseId : Offset.slt(O.Offset); } - GetElementPtrInst *GEP; - LoadInst *LoadI; + GetElementPtrInst *GEP = nullptr; + LoadInst *LoadI = nullptr; + unsigned BaseId = 0; APInt Offset; }; +// A class that assigns increasing ids to values in the order in which they are +// seen. See comment in `BCEAtom::operator<()``. +class BaseIdentifier { +public: + // Returns the id for value `Base`, after assigning one if `Base` has not been + // seen before. + int getBaseId(const Value *Base) { + assert(Base && "invalid base"); + const auto Insertion = BaseToIndex.try_emplace(Base, Order); + if (Insertion.second) + ++Order; + return Insertion.first->second; + } + +private: + unsigned Order = 1; + DenseMap<const Value*, int> BaseToIndex; +}; + // If this value is a load from a constant offset w.r.t. a base address, and // there are no other users of the load or address, returns the base address and // the offset. -BCEAtom visitICmpLoadOperand(Value *const Val) { - BCEAtom Result; - if (auto *const LoadI = dyn_cast<LoadInst>(Val)) { - LLVM_DEBUG(dbgs() << "load\n"); - if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) { - LLVM_DEBUG(dbgs() << "used outside of block\n"); - return {}; - } - // Do not optimize atomic loads to non-atomic memcmp - if (!LoadI->isSimple()) { - LLVM_DEBUG(dbgs() << "volatile or atomic\n"); - return {}; - } - Value *const Addr = LoadI->getOperand(0); - if (auto *const GEP = dyn_cast<GetElementPtrInst>(Addr)) { - LLVM_DEBUG(dbgs() << "GEP\n"); - if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) { - LLVM_DEBUG(dbgs() << "used outside of block\n"); - return {}; - } - const auto &DL = GEP->getModule()->getDataLayout(); - if (!isDereferenceablePointer(GEP, DL)) { - LLVM_DEBUG(dbgs() << "not dereferenceable\n"); - // We need to make sure that we can do comparison in any order, so we - // require memory to be unconditionnally dereferencable. - return {}; - } - Result.Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0); - if (GEP->accumulateConstantOffset(DL, Result.Offset)) { - Result.GEP = GEP; - Result.LoadI = LoadI; - } - } +BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { + auto *const LoadI = dyn_cast<LoadInst>(Val); + if (!LoadI) + return {}; + LLVM_DEBUG(dbgs() << "load\n"); + if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) { + LLVM_DEBUG(dbgs() << "used outside of block\n"); + return {}; + } + // Do not optimize atomic loads to non-atomic memcmp + if (!LoadI->isSimple()) { + LLVM_DEBUG(dbgs() << "volatile or atomic\n"); + return {}; } - return Result; + Value *const Addr = LoadI->getOperand(0); + auto *const GEP = dyn_cast<GetElementPtrInst>(Addr); + if (!GEP) + return {}; + LLVM_DEBUG(dbgs() << "GEP\n"); + if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) { + LLVM_DEBUG(dbgs() << "used outside of block\n"); + return {}; + } + const auto &DL = GEP->getModule()->getDataLayout(); + if (!isDereferenceablePointer(GEP, LoadI->getType(), DL)) { + LLVM_DEBUG(dbgs() << "not dereferenceable\n"); + // We need to make sure that we can do comparison in any order, so we + // require memory to be unconditionnally dereferencable. + return {}; + } + APInt Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0); + if (!GEP->accumulateConstantOffset(DL, Offset)) + return {}; + return BCEAtom(GEP, LoadI, BaseId.getBaseId(GEP->getPointerOperand()), + Offset); } -// A basic block with a comparison between two BCE atoms. +// A basic block with a comparison between two BCE atoms, e.g. `a == o.a` in the +// example at the top. 
// The block might do extra work besides the atom comparison, in which case // doesOtherWork() returns true. Under some conditions, the block can be // split into the atom comparison part and the "other work" part @@ -133,13 +190,11 @@ class BCECmpBlock { BCECmpBlock() {} BCECmpBlock(BCEAtom L, BCEAtom R, int SizeBits) - : Lhs_(L), Rhs_(R), SizeBits_(SizeBits) { + : Lhs_(std::move(L)), Rhs_(std::move(R)), SizeBits_(SizeBits) { if (Rhs_ < Lhs_) std::swap(Rhs_, Lhs_); } - bool IsValid() const { - return Lhs_.Base() != nullptr && Rhs_.Base() != nullptr; - } + bool IsValid() const { return Lhs_.BaseId != 0 && Rhs_.BaseId != 0; } // Assert the block is consistent: If valid, it should also have // non-null members besides Lhs_ and Rhs_. @@ -160,19 +215,19 @@ class BCECmpBlock { // Returns true if the non-BCE-cmp instructions can be separated from BCE-cmp // instructions in the block. - bool canSplit(AliasAnalysis *AA) const; + bool canSplit(AliasAnalysis &AA) const; // Return true if this all the relevant instructions in the BCE-cmp-block can // be sunk below this instruction. By doing this, we know we can separate the // BCE-cmp-block instructions from the non-BCE-cmp-block instructions in the // block. bool canSinkBCECmpInst(const Instruction *, DenseSet<Instruction *> &, - AliasAnalysis *AA) const; + AliasAnalysis &AA) const; // We can separate the BCE-cmp-block instructions and the non-BCE-cmp-block // instructions. Split the old block and move all non-BCE-cmp-insts into the // new parent block. - void split(BasicBlock *NewParent, AliasAnalysis *AA) const; + void split(BasicBlock *NewParent, AliasAnalysis &AA) const; // The basic block where this comparison happens. BasicBlock *BB = nullptr; @@ -191,7 +246,7 @@ private: bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, DenseSet<Instruction *> &BlockInsts, - AliasAnalysis *AA) const { + AliasAnalysis &AA) const { // If this instruction has side effects and its in middle of the BCE cmp block // instructions, then bail for now. if (Inst->mayHaveSideEffects()) { @@ -201,9 +256,9 @@ bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, // Disallow stores that might alias the BCE operands MemoryLocation LLoc = MemoryLocation::get(Lhs_.LoadI); MemoryLocation RLoc = MemoryLocation::get(Rhs_.LoadI); - if (isModSet(AA->getModRefInfo(Inst, LLoc)) || - isModSet(AA->getModRefInfo(Inst, RLoc))) - return false; + if (isModSet(AA.getModRefInfo(Inst, LLoc)) || + isModSet(AA.getModRefInfo(Inst, RLoc))) + return false; } // Make sure this instruction does not use any of the BCE cmp block // instructions as operand. @@ -214,7 +269,7 @@ bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, return true; } -void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis *AA) const { +void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis &AA) const { DenseSet<Instruction *> BlockInsts( {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI}); llvm::SmallVector<Instruction *, 4> OtherInsts; @@ -234,7 +289,7 @@ void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis *AA) const { } } -bool BCECmpBlock::canSplit(AliasAnalysis *AA) const { +bool BCECmpBlock::canSplit(AliasAnalysis &AA) const { DenseSet<Instruction *> BlockInsts( {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI}); for (Instruction &Inst : *BB) { @@ -265,7 +320,8 @@ bool BCECmpBlock::doesOtherWork() const { // Visit the given comparison. If this is a comparison between two valid // BCE atoms, returns the comparison. 
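// The BaseIdentifier scheme introduced above keys the comparison order on the
// order in which base pointers are first seen rather than on the pointer
// values themselves, which keeps the sort deterministic from run to run. A
// minimal standalone analogue of that idea, using the standard library instead
// of DenseMap (class and function names are illustrative):

#include <cassert>
#include <unordered_map>

class OrderOfAppearance {
  unsigned Next = 1;
  std::unordered_map<const void *, unsigned> Ids;

public:
  unsigned getId(const void *Base) {
    // try_emplace inserts only on first sight, so the stored id records the
    // order of first appearance and later queries return the same value.
    auto Ins = Ids.try_emplace(Base, Next);
    if (Ins.second)
      ++Next;
    return Ins.first->second;
  }
};

int main() {
  int A = 0, B = 0;
  OrderOfAppearance Order;
  assert(Order.getId(&B) == 1); // the first base seen gets id 1 ...
  assert(Order.getId(&A) == 2); // ... the next distinct base gets id 2 ...
  assert(Order.getId(&B) == 1); // ... and repeated lookups are stable.
  return 0;
}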
BCECmpBlock visitICmp(const ICmpInst *const CmpI, - const ICmpInst::Predicate ExpectedPredicate) { + const ICmpInst::Predicate ExpectedPredicate, + BaseIdentifier &BaseId) { // The comparison can only be used once: // - For intermediate blocks, as a branch condition. // - For the final block, as an incoming value for the Phi. @@ -275,25 +331,27 @@ BCECmpBlock visitICmp(const ICmpInst *const CmpI, LLVM_DEBUG(dbgs() << "cmp has several uses\n"); return {}; } - if (CmpI->getPredicate() == ExpectedPredicate) { - LLVM_DEBUG(dbgs() << "cmp " - << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne") - << "\n"); - auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0)); - if (!Lhs.Base()) return {}; - auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1)); - if (!Rhs.Base()) return {}; - const auto &DL = CmpI->getModule()->getDataLayout(); - return BCECmpBlock(std::move(Lhs), std::move(Rhs), - DL.getTypeSizeInBits(CmpI->getOperand(0)->getType())); - } - return {}; + if (CmpI->getPredicate() != ExpectedPredicate) + return {}; + LLVM_DEBUG(dbgs() << "cmp " + << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne") + << "\n"); + auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0), BaseId); + if (!Lhs.BaseId) + return {}; + auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1), BaseId); + if (!Rhs.BaseId) + return {}; + const auto &DL = CmpI->getModule()->getDataLayout(); + return BCECmpBlock(std::move(Lhs), std::move(Rhs), + DL.getTypeSizeInBits(CmpI->getOperand(0)->getType())); } // Visit the given comparison block. If this is a comparison between two valid // BCE atoms, returns the comparison. BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, - const BasicBlock *const PhiBlock) { + const BasicBlock *const PhiBlock, + BaseIdentifier &BaseId) { if (Block->empty()) return {}; auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator()); if (!BranchI) return {}; @@ -306,7 +364,7 @@ BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, auto *const CmpI = dyn_cast<ICmpInst>(Val); if (!CmpI) return {}; LLVM_DEBUG(dbgs() << "icmp\n"); - auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ); + auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ, BaseId); Result.CmpI = CmpI; Result.BranchI = BranchI; return Result; @@ -323,7 +381,8 @@ BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch"); BasicBlock *const FalseBlock = BranchI->getSuccessor(1); auto Result = visitICmp( - CmpI, FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE); + CmpI, FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + BaseId); Result.CmpI = CmpI; Result.BranchI = BranchI; return Result; @@ -332,47 +391,41 @@ BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, } static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons, - BCECmpBlock &Comparison) { + BCECmpBlock &&Comparison) { LLVM_DEBUG(dbgs() << "Block '" << Comparison.BB->getName() << "': Found cmp of " << Comparison.SizeBits() - << " bits between " << Comparison.Lhs().Base() << " + " + << " bits between " << Comparison.Lhs().BaseId << " + " << Comparison.Lhs().Offset << " and " - << Comparison.Rhs().Base() << " + " + << Comparison.Rhs().BaseId << " + " << Comparison.Rhs().Offset << "\n"); LLVM_DEBUG(dbgs() << "\n"); - Comparisons.push_back(Comparison); + Comparisons.push_back(std::move(Comparison)); } // A chain of comparisons. 
class BCECmpChain { public: - BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, - AliasAnalysis *AA); + BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, + AliasAnalysis &AA); - int size() const { return Comparisons_.size(); } + int size() const { return Comparisons_.size(); } #ifdef MERGEICMPS_DOT_ON void dump() const; #endif // MERGEICMPS_DOT_ON - bool simplify(const TargetLibraryInfo *const TLI, AliasAnalysis *AA); + bool simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, + DomTreeUpdater &DTU); - private: +private: static bool IsContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) { - return First.Lhs().Base() == Second.Lhs().Base() && - First.Rhs().Base() == Second.Rhs().Base() && + return First.Lhs().BaseId == Second.Lhs().BaseId && + First.Rhs().BaseId == Second.Rhs().BaseId && First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset && First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset; } - // Merges the given comparison blocks into one memcmp block and update - // branches. Comparisons are assumed to be continguous. If NextBBInChain is - // null, the merged block will link to the phi block. - void mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, - BasicBlock *const NextBBInChain, PHINode &Phi, - const TargetLibraryInfo *const TLI, AliasAnalysis *AA); - PHINode &Phi_; std::vector<BCECmpBlock> Comparisons_; // The original entry block (before sorting); @@ -380,16 +433,17 @@ class BCECmpChain { }; BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, - AliasAnalysis *AA) + AliasAnalysis &AA) : Phi_(Phi) { assert(!Blocks.empty() && "a chain should have at least one block"); // Now look inside blocks to check for BCE comparisons. std::vector<BCECmpBlock> Comparisons; + BaseIdentifier BaseId; for (size_t BlockIdx = 0; BlockIdx < Blocks.size(); ++BlockIdx) { BasicBlock *const Block = Blocks[BlockIdx]; assert(Block && "invalid block"); BCECmpBlock Comparison = visitCmpBlock(Phi.getIncomingValueForBlock(Block), - Block, Phi.getParent()); + Block, Phi.getParent(), BaseId); Comparison.BB = Block; if (!Comparison.IsValid()) { LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n"); @@ -411,13 +465,13 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, // chain before sorting. Unless we can abort the chain at this point // and start anew. // - // NOTE: we only handle block with single predecessor for now. + // NOTE: we only handle blocks a with single predecessor for now. if (Comparison.canSplit(AA)) { LLVM_DEBUG(dbgs() << "Split initial block '" << Comparison.BB->getName() << "' that does extra work besides compare\n"); Comparison.RequireSplit = true; - enqueueBlock(Comparisons, Comparison); + enqueueBlock(Comparisons, std::move(Comparison)); } else { LLVM_DEBUG(dbgs() << "ignoring initial block '" << Comparison.BB->getName() @@ -450,7 +504,7 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, // We could still merge bb1 and bb2 though. return; } - enqueueBlock(Comparisons, Comparison); + enqueueBlock(Comparisons, std::move(Comparison)); } // It is possible we have no suitable comparison to merge. @@ -466,9 +520,11 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, #endif // MERGEICMPS_DOT_ON // Reorder blocks by LHS. We can do that without changing the // semantics because we are only accessing dereferencable memory. 
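// IsContiguous above accepts two comparisons when the second one starts
// exactly where the first one ends (Offset + SizeBits / 8 on both sides).
// For the struct S from the file header comment that is what lets the whole
// chain collapse into a single 8-byte memcmp. A small standalone check of
// those offsets, assuming a typical target where int is 4 bytes (the concrete
// layout is an assumption of this sketch):

#include <cstddef>
#include <cstdint>
#include <cstdio>

struct S {
  int a;      // offset 0, 4 bytes
  char b;     // offset 4, 1 byte
  char c;     // offset 5, 1 byte
  uint16_t d; // offset 6, 2 bytes -> 8 bytes total, no padding gaps
};

int main() {
  std::printf("a:%zu b:%zu c:%zu d:%zu size:%zu\n", offsetof(S, a),
              offsetof(S, b), offsetof(S, c), offsetof(S, d), sizeof(S));
  // Each field begins where the previous one ends (0+4==4, 4+1==5, 5+1==6),
  // so the four loads cover bytes [0, 8) contiguously and the chain becomes
  // memcmp(this, &o, 8) == 0, later expanded to one 8-byte compare.
  return 0;
}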
- llvm::sort(Comparisons_, [](const BCECmpBlock &a, const BCECmpBlock &b) { - return a.Lhs() < b.Lhs(); - }); + llvm::sort(Comparisons_, + [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) { + return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) < + std::tie(RhsBlock.Lhs(), RhsBlock.Rhs()); + }); #ifdef MERGEICMPS_DOT_ON errs() << "AFTER REORDERING:\n\n"; dump(); @@ -498,162 +554,205 @@ void BCECmpChain::dump() const { } #endif // MERGEICMPS_DOT_ON -bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI, - AliasAnalysis *AA) { - // First pass to check if there is at least one merge. If not, we don't do - // anything and we keep analysis passes intact. - { - bool AtLeastOneMerged = false; - for (size_t I = 1; I < Comparisons_.size(); ++I) { - if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) { - AtLeastOneMerged = true; - break; +namespace { + +// A class to compute the name of a set of merged basic blocks. +// This is optimized for the common case of no block names. +class MergedBlockName { + // Storage for the uncommon case of several named blocks. + SmallString<16> Scratch; + +public: + explicit MergedBlockName(ArrayRef<BCECmpBlock> Comparisons) + : Name(makeName(Comparisons)) {} + const StringRef Name; + +private: + StringRef makeName(ArrayRef<BCECmpBlock> Comparisons) { + assert(!Comparisons.empty() && "no basic block"); + // Fast path: only one block, or no names at all. + if (Comparisons.size() == 1) + return Comparisons[0].BB->getName(); + const int size = std::accumulate(Comparisons.begin(), Comparisons.end(), 0, + [](int i, const BCECmpBlock &Cmp) { + return i + Cmp.BB->getName().size(); + }); + if (size == 0) + return StringRef("", 0); + + // Slow path: at least two blocks, at least one block with a name. + Scratch.clear(); + // We'll have `size` bytes for name and `Comparisons.size() - 1` bytes for + // separators. + Scratch.reserve(size + Comparisons.size() - 1); + const auto append = [this](StringRef str) { + Scratch.append(str.begin(), str.end()); + }; + append(Comparisons[0].BB->getName()); + for (int I = 1, E = Comparisons.size(); I < E; ++I) { + const BasicBlock *const BB = Comparisons[I].BB; + if (!BB->getName().empty()) { + append("+"); + append(BB->getName()); } } - if (!AtLeastOneMerged) return false; + return StringRef(Scratch); } +}; +} // namespace + +// Merges the given contiguous comparison blocks into one memcmp block. +static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, + BasicBlock *const InsertBefore, + BasicBlock *const NextCmpBlock, + PHINode &Phi, const TargetLibraryInfo &TLI, + AliasAnalysis &AA, DomTreeUpdater &DTU) { + assert(!Comparisons.empty() && "merging zero comparisons"); + LLVMContext &Context = NextCmpBlock->getContext(); + const BCECmpBlock &FirstCmp = Comparisons[0]; + + // Create a new cmp block before next cmp block. + BasicBlock *const BB = + BasicBlock::Create(Context, MergedBlockName(Comparisons).Name, + NextCmpBlock->getParent(), InsertBefore); + IRBuilder<> Builder(BB); + // Add the GEPs from the first BCECmpBlock. 
+ Value *const Lhs = Builder.Insert(FirstCmp.Lhs().GEP->clone()); + Value *const Rhs = Builder.Insert(FirstCmp.Rhs().GEP->clone()); + + Value *IsEqual = nullptr; + LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> " + << BB->getName() << "\n"); + if (Comparisons.size() == 1) { + LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n"); + Value *const LhsLoad = + Builder.CreateLoad(FirstCmp.Lhs().LoadI->getType(), Lhs); + Value *const RhsLoad = + Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs); + // There are no blocks to merge, just do the comparison. + IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad); + } else { + // If there is one block that requires splitting, we do it now, i.e. + // just before we know we will collapse the chain. The instructions + // can be executed before any of the instructions in the chain. + const auto ToSplit = + std::find_if(Comparisons.begin(), Comparisons.end(), + [](const BCECmpBlock &B) { return B.RequireSplit; }); + if (ToSplit != Comparisons.end()) { + LLVM_DEBUG(dbgs() << "Splitting non_BCE work to header\n"); + ToSplit->split(BB, AA); + } - // Remove phi references to comparison blocks, they will be rebuilt as we - // merge the blocks. - for (const auto &Comparison : Comparisons_) { - Phi_.removeIncomingValue(Comparison.BB, false); - } + const unsigned TotalSizeBits = std::accumulate( + Comparisons.begin(), Comparisons.end(), 0u, + [](int Size, const BCECmpBlock &C) { return Size + C.SizeBits(); }); - // If entry block is part of the chain, we need to make the first block - // of the chain the new entry block of the function. - BasicBlock *Entry = &Comparisons_[0].BB->getParent()->getEntryBlock(); - for (size_t I = 1; I < Comparisons_.size(); ++I) { - if (Entry == Comparisons_[I].BB) { - BasicBlock *NEntryBB = BasicBlock::Create(Entry->getContext(), "", - Entry->getParent(), Entry); - BranchInst::Create(Entry, NEntryBB); - break; - } + // Create memcmp() == 0. + const auto &DL = Phi.getModule()->getDataLayout(); + Value *const MemCmpCall = emitMemCmp( + Lhs, Rhs, + ConstantInt::get(DL.getIntPtrType(Context), TotalSizeBits / 8), Builder, + DL, &TLI); + IsEqual = Builder.CreateICmpEQ( + MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0)); } - // Point the predecessors of the chain to the first comparison block (which is - // the new entry point) and update the entry block of the chain. - if (EntryBlock_ != Comparisons_[0].BB) { - EntryBlock_->replaceAllUsesWith(Comparisons_[0].BB); - EntryBlock_ = Comparisons_[0].BB; + BasicBlock *const PhiBB = Phi.getParent(); + // Add a branch to the next basic block in the chain. + if (NextCmpBlock == PhiBB) { + // Continue to phi, passing it the comparison result. + Builder.CreateBr(PhiBB); + Phi.addIncoming(IsEqual, BB); + DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}}); + } else { + // Continue to next block if equal, exit to phi else. + Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB); + Phi.addIncoming(ConstantInt::getFalse(Context), BB); + DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock}, + {DominatorTree::Insert, BB, PhiBB}}); } + return BB; +} + +bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, + DomTreeUpdater &DTU) { + assert(Comparisons_.size() >= 2 && "simplifying trivial BCECmpChain"); + // First pass to check if there is at least one merge. If not, we don't do + // anything and we keep analysis passes intact. 
+ const auto AtLeastOneMerged = [this]() { + for (size_t I = 1; I < Comparisons_.size(); ++I) { + if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) + return true; + } + return false; + }; + if (!AtLeastOneMerged()) + return false; - // Effectively merge blocks. + LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block " + << EntryBlock_->getName() << "\n"); + + // Effectively merge blocks. We go in the reverse direction from the phi block + // so that the next block is always available to branch to. + const auto mergeRange = [this, &TLI, &AA, &DTU](int I, int Num, + BasicBlock *InsertBefore, + BasicBlock *Next) { + return mergeComparisons(makeArrayRef(Comparisons_).slice(I, Num), + InsertBefore, Next, Phi_, TLI, AA, DTU); + }; int NumMerged = 1; - for (size_t I = 1; I < Comparisons_.size(); ++I) { - if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) { + BasicBlock *NextCmpBlock = Phi_.getParent(); + for (int I = static_cast<int>(Comparisons_.size()) - 2; I >= 0; --I) { + if (IsContiguous(Comparisons_[I], Comparisons_[I + 1])) { + LLVM_DEBUG(dbgs() << "Merging block " << Comparisons_[I].BB->getName() + << " into " << Comparisons_[I + 1].BB->getName() + << "\n"); ++NumMerged; } else { - // Merge all previous comparisons and start a new merge block. - mergeComparisons( - makeArrayRef(Comparisons_).slice(I - NumMerged, NumMerged), - Comparisons_[I].BB, Phi_, TLI, AA); + NextCmpBlock = mergeRange(I + 1, NumMerged, NextCmpBlock, NextCmpBlock); NumMerged = 1; } } - mergeComparisons(makeArrayRef(Comparisons_) - .slice(Comparisons_.size() - NumMerged, NumMerged), - nullptr, Phi_, TLI, AA); - - return true; -} - -void BCECmpChain::mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, - BasicBlock *const NextBBInChain, - PHINode &Phi, - const TargetLibraryInfo *const TLI, - AliasAnalysis *AA) { - assert(!Comparisons.empty()); - const auto &FirstComparison = *Comparisons.begin(); - BasicBlock *const BB = FirstComparison.BB; - LLVMContext &Context = BB->getContext(); - - if (Comparisons.size() >= 2) { - // If there is one block that requires splitting, we do it now, i.e. - // just before we know we will collapse the chain. The instructions - // can be executed before any of the instructions in the chain. - auto C = std::find_if(Comparisons.begin(), Comparisons.end(), - [](const BCECmpBlock &B) { return B.RequireSplit; }); - if (C != Comparisons.end()) - C->split(EntryBlock_, AA); - - LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons\n"); - const auto TotalSize = - std::accumulate(Comparisons.begin(), Comparisons.end(), 0, - [](int Size, const BCECmpBlock &C) { - return Size + C.SizeBits(); - }) / - 8; - - // Incoming edges do not need to be updated, and both GEPs are already - // computing the right address, we just need to: - // - replace the two loads and the icmp with the memcmp - // - update the branch - // - update the incoming values in the phi. 
- FirstComparison.BranchI->eraseFromParent(); - FirstComparison.CmpI->eraseFromParent(); - FirstComparison.Lhs().LoadI->eraseFromParent(); - FirstComparison.Rhs().LoadI->eraseFromParent(); - - IRBuilder<> Builder(BB); - const auto &DL = Phi.getModule()->getDataLayout(); - Value *const MemCmpCall = emitMemCmp( - FirstComparison.Lhs().GEP, FirstComparison.Rhs().GEP, - ConstantInt::get(DL.getIntPtrType(Context), TotalSize), - Builder, DL, TLI); - Value *const MemCmpIsZero = Builder.CreateICmpEQ( - MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0)); + // Insert the entry block for the new chain before the old entry block. + // If the old entry block was the function entry, this ensures that the new + // entry can become the function entry. + NextCmpBlock = mergeRange(0, NumMerged, EntryBlock_, NextCmpBlock); + + // Replace the original cmp chain with the new cmp chain by pointing all + // predecessors of EntryBlock_ to NextCmpBlock instead. This makes all cmp + // blocks in the old chain unreachable. + while (!pred_empty(EntryBlock_)) { + BasicBlock* const Pred = *pred_begin(EntryBlock_); + LLVM_DEBUG(dbgs() << "Updating jump into old chain from " << Pred->getName() + << "\n"); + Pred->getTerminator()->replaceUsesOfWith(EntryBlock_, NextCmpBlock); + DTU.applyUpdates({{DominatorTree::Delete, Pred, EntryBlock_}, + {DominatorTree::Insert, Pred, NextCmpBlock}}); + } - // Add a branch to the next basic block in the chain. - if (NextBBInChain) { - Builder.CreateCondBr(MemCmpIsZero, NextBBInChain, Phi.getParent()); - Phi.addIncoming(ConstantInt::getFalse(Context), BB); - } else { - Builder.CreateBr(Phi.getParent()); - Phi.addIncoming(MemCmpIsZero, BB); - } + // If the old cmp chain was the function entry, we need to update the function + // entry. + const bool ChainEntryIsFnEntry = + (EntryBlock_ == &EntryBlock_->getParent()->getEntryBlock()); + if (ChainEntryIsFnEntry && DTU.hasDomTree()) { + LLVM_DEBUG(dbgs() << "Changing function entry from " + << EntryBlock_->getName() << " to " + << NextCmpBlock->getName() << "\n"); + DTU.getDomTree().setNewRoot(NextCmpBlock); + DTU.applyUpdates({{DominatorTree::Delete, NextCmpBlock, EntryBlock_}}); + } + EntryBlock_ = nullptr; - // Delete merged blocks. - for (size_t I = 1; I < Comparisons.size(); ++I) { - BasicBlock *CBB = Comparisons[I].BB; - CBB->replaceAllUsesWith(BB); - CBB->eraseFromParent(); - } - } else { - assert(Comparisons.size() == 1); - // There are no blocks to merge, but we still need to update the branches. - LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n"); - if (NextBBInChain) { - if (FirstComparison.BranchI->isConditional()) { - LLVM_DEBUG(dbgs() << "conditional -> conditional\n"); - // Just update the "true" target, the "false" target should already be - // the phi block. - assert(FirstComparison.BranchI->getSuccessor(1) == Phi.getParent()); - FirstComparison.BranchI->setSuccessor(0, NextBBInChain); - Phi.addIncoming(ConstantInt::getFalse(Context), BB); - } else { - LLVM_DEBUG(dbgs() << "unconditional -> conditional\n"); - // Replace the unconditional branch by a conditional one. - FirstComparison.BranchI->eraseFromParent(); - IRBuilder<> Builder(BB); - Builder.CreateCondBr(FirstComparison.CmpI, NextBBInChain, - Phi.getParent()); - Phi.addIncoming(FirstComparison.CmpI, BB); - } - } else { - if (FirstComparison.BranchI->isConditional()) { - LLVM_DEBUG(dbgs() << "conditional -> unconditional\n"); - // Replace the conditional branch by an unconditional one. 
- FirstComparison.BranchI->eraseFromParent(); - IRBuilder<> Builder(BB); - Builder.CreateBr(Phi.getParent()); - Phi.addIncoming(FirstComparison.CmpI, BB); - } else { - LLVM_DEBUG(dbgs() << "unconditional -> unconditional\n"); - Phi.addIncoming(FirstComparison.CmpI, BB); - } - } + // Delete merged blocks. This also removes incoming values in phi. + SmallVector<BasicBlock *, 16> DeadBlocks; + for (auto &Cmp : Comparisons_) { + LLVM_DEBUG(dbgs() << "Deleting merged block " << Cmp.BB->getName() << "\n"); + DeadBlocks.push_back(Cmp.BB); } + DeleteDeadBlocks(DeadBlocks, &DTU); + + Comparisons_.clear(); + return true; } std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, @@ -691,8 +790,8 @@ std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, return Blocks; } -bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI, - AliasAnalysis *AA) { +bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA, + DomTreeUpdater &DTU) { LLVM_DEBUG(dbgs() << "processPhi()\n"); if (Phi.getNumIncomingValues() <= 1) { LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n"); @@ -757,24 +856,54 @@ bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI, return false; } - return CmpChain.simplify(TLI, AA); + return CmpChain.simplify(TLI, AA, DTU); } -class MergeICmps : public FunctionPass { - public: +static bool runImpl(Function &F, const TargetLibraryInfo &TLI, + const TargetTransformInfo &TTI, AliasAnalysis &AA, + DominatorTree *DT) { + LLVM_DEBUG(dbgs() << "MergeICmpsLegacyPass: " << F.getName() << "\n"); + + // We only try merging comparisons if the target wants to expand memcmp later. + // The rationale is to avoid turning small chains into memcmp calls. + if (!TTI.enableMemCmpExpansion(F.hasOptSize(), true)) + return false; + + // If we don't have memcmp avaiable we can't emit calls to it. + if (!TLI.has(LibFunc_memcmp)) + return false; + + DomTreeUpdater DTU(DT, /*PostDominatorTree*/ nullptr, + DomTreeUpdater::UpdateStrategy::Eager); + + bool MadeChange = false; + + for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) { + // A Phi operation is always first in a basic block. + if (auto *const Phi = dyn_cast<PHINode>(&*BBIt->begin())) + MadeChange |= processPhi(*Phi, TLI, AA, DTU); + } + + return MadeChange; +} + +class MergeICmpsLegacyPass : public FunctionPass { +public: static char ID; - MergeICmps() : FunctionPass(ID) { - initializeMergeICmpsPass(*PassRegistry::getPassRegistry()); + MergeICmpsLegacyPass() : FunctionPass(ID) { + initializeMergeICmpsLegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto PA = runImpl(F, &TLI, &TTI, AA); - return !PA.areAllPreserved(); + // MergeICmps does not need the DominatorTree, but we update it if it's + // already available. + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + return runImpl(F, TLI, TTI, AA, DTWP ? 
&DTWP->getDomTree() : nullptr); } private: @@ -782,46 +911,35 @@ class MergeICmps : public FunctionPass { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); } - - PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, AliasAnalysis *AA); }; -PreservedAnalyses MergeICmps::runImpl(Function &F, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, - AliasAnalysis *AA) { - LLVM_DEBUG(dbgs() << "MergeICmpsPass: " << F.getName() << "\n"); - - // We only try merging comparisons if the target wants to expand memcmp later. - // The rationale is to avoid turning small chains into memcmp calls. - if (!TTI->enableMemCmpExpansion(true)) return PreservedAnalyses::all(); - - // If we don't have memcmp avaiable we can't emit calls to it. - if (!TLI->has(LibFunc_memcmp)) - return PreservedAnalyses::all(); - - bool MadeChange = false; - - for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) { - // A Phi operation is always first in a basic block. - if (auto *const Phi = dyn_cast<PHINode>(&*BBIt->begin())) - MadeChange |= processPhi(*Phi, TLI, AA); - } - - if (MadeChange) return PreservedAnalyses::none(); - return PreservedAnalyses::all(); -} +} // namespace -} // namespace - -char MergeICmps::ID = 0; -INITIALIZE_PASS_BEGIN(MergeICmps, "mergeicmps", +char MergeICmpsLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(MergeICmpsLegacyPass, "mergeicmps", "Merge contiguous icmps into a memcmp", false, false) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(MergeICmps, "mergeicmps", +INITIALIZE_PASS_END(MergeICmpsLegacyPass, "mergeicmps", "Merge contiguous icmps into a memcmp", false, false) -Pass *llvm::createMergeICmpsPass() { return new MergeICmps(); } +Pass *llvm::createMergeICmpsLegacyPass() { return new MergeICmpsLegacyPass(); } + +PreservedAnalyses MergeICmpsPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); + const bool MadeChanges = runImpl(F, TLI, TTI, AA, DT); + if (!MadeChanges) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + PA.preserve<DominatorTreeAnalysis>(); + return PA; +} diff --git a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index ee21feca8d2c..30645f4400e3 100644 --- a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -1,9 +1,8 @@ //===- MergedLoadStoreMotion.cpp - merge and hoist/sink load/stores -------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/NaryReassociate.cpp b/lib/Transforms/Scalar/NaryReassociate.cpp index 7106ea216ad6..94436b55752a 100644 --- a/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/lib/Transforms/Scalar/NaryReassociate.cpp @@ -1,9 +1,8 @@ //===- NaryReassociate.cpp - Reassociate n-ary expressions ----------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -427,8 +426,8 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, RHS = Builder.CreateMul( RHS, ConstantInt::get(IntPtrTy, IndexedSize / ElementSize)); } - GetElementPtrInst *NewGEP = - cast<GetElementPtrInst>(Builder.CreateGEP(Candidate, RHS)); + GetElementPtrInst *NewGEP = cast<GetElementPtrInst>( + Builder.CreateGEP(GEP->getResultElementType(), Candidate, RHS)); NewGEP->setIsInBounds(GEP->isInBounds()); NewGEP->takeName(GEP); return NewGEP; diff --git a/lib/Transforms/Scalar/NewGVN.cpp b/lib/Transforms/Scalar/NewGVN.cpp index 7cbb0fe70f82..08ac2b666fce 100644 --- a/lib/Transforms/Scalar/NewGVN.cpp +++ b/lib/Transforms/Scalar/NewGVN.cpp @@ -1,9 +1,8 @@ //===- NewGVN.cpp - Global Value Numbering Pass ---------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -1167,9 +1166,9 @@ const Expression *NewGVN::createExpression(Instruction *I) const { SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1), SQ); if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) return SimplifiedE; - } else if (auto *BI = dyn_cast<BitCastInst>(I)) { + } else if (auto *CI = dyn_cast<CastInst>(I)) { Value *V = - SimplifyCastInst(BI->getOpcode(), BI->getOperand(0), BI->getType(), SQ); + SimplifyCastInst(CI->getOpcode(), E->getOperand(0), CI->getType(), SQ); if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) return SimplifiedE; } else if (isa<GetElementPtrInst>(I)) { @@ -1815,39 +1814,13 @@ NewGVN::performSymbolicPHIEvaluation(ArrayRef<ValPair> PHIOps, const Expression * NewGVN::performSymbolicAggrValueEvaluation(Instruction *I) const { if (auto *EI = dyn_cast<ExtractValueInst>(I)) { - auto *II = dyn_cast<IntrinsicInst>(EI->getAggregateOperand()); - if (II && EI->getNumIndices() == 1 && *EI->idx_begin() == 0) { - unsigned Opcode = 0; - // EI might be an extract from one of our recognised intrinsics. If it - // is we'll synthesize a semantically equivalent expression instead on - // an extract value expression. 
- switch (II->getIntrinsicID()) { - case Intrinsic::sadd_with_overflow: - case Intrinsic::uadd_with_overflow: - Opcode = Instruction::Add; - break; - case Intrinsic::ssub_with_overflow: - case Intrinsic::usub_with_overflow: - Opcode = Instruction::Sub; - break; - case Intrinsic::smul_with_overflow: - case Intrinsic::umul_with_overflow: - Opcode = Instruction::Mul; - break; - default: - break; - } - - if (Opcode != 0) { - // Intrinsic recognized. Grab its args to finish building the - // expression. - assert(II->getNumArgOperands() == 2 && - "Expect two args for recognised intrinsics."); - return createBinaryExpression(Opcode, EI->getType(), - II->getArgOperand(0), - II->getArgOperand(1), I); - } - } + auto *WO = dyn_cast<WithOverflowInst>(EI->getAggregateOperand()); + if (WO && EI->getNumIndices() == 1 && *EI->idx_begin() == 0) + // EI is an extract from one of our with.overflow intrinsics. Synthesize + // a semantically equivalent expression instead of an extract value + // expression. + return createBinaryExpression(WO->getBinaryOp(), EI->getType(), + WO->getLHS(), WO->getRHS(), I); } return createAggregateValueExpression(I); @@ -2011,12 +1984,14 @@ NewGVN::performSymbolicEvaluation(Value *V, E = performSymbolicLoadEvaluation(I); break; case Instruction::BitCast: + case Instruction::AddrSpaceCast: E = createExpression(I); break; case Instruction::ICmp: case Instruction::FCmp: E = performSymbolicCmpEvaluation(I); break; + case Instruction::FNeg: case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -2122,7 +2097,7 @@ void NewGVN::addPredicateUsers(const PredicateBase *PB, Instruction *I) const { if (auto *PBranch = dyn_cast<PredicateBranch>(PB)) PredicateToUsers[PBranch->Condition].insert(I); - else if (auto *PAssume = dyn_cast<PredicateBranch>(PB)) + else if (auto *PAssume = dyn_cast<PredicateAssume>(PB)) PredicateToUsers[PAssume->Condition].insert(I); } @@ -2524,9 +2499,6 @@ void NewGVN::processOutgoingEdges(Instruction *TI, BasicBlock *B) { // For switches, propagate the case values into the case // destinations. - // Remember how many outgoing edges there are to every successor. - SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges; - Value *SwitchCond = SI->getCondition(); Value *CondEvaluated = findConditionEquivalence(SwitchCond); // See if we were able to turn this switch statement into a constant. @@ -2547,7 +2519,6 @@ void NewGVN::processOutgoingEdges(Instruction *TI, BasicBlock *B) { } else { for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { BasicBlock *TargetBlock = SI->getSuccessor(i); - ++SwitchEdges[TargetBlock]; updateReachableEdge(B, TargetBlock); } } @@ -3503,7 +3474,7 @@ bool NewGVN::runGVN() { "BB containing ToErase deleted unexpectedly!"); ToErase->eraseFromParent(); } - Changed |= !InstructionsToErase.empty(); + Changed |= !InstructionsToErase.empty(); // Delete all unreachable blocks. auto UnreachableBlockPred = [&](const BasicBlock &BB) { diff --git a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index 05ea9144f66c..039123218544 100644 --- a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -1,9 +1,8 @@ //===--- PartiallyInlineLibCalls.cpp - Partially inline libcalls ----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. 
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/PlaceSafepoints.cpp b/lib/Transforms/Scalar/PlaceSafepoints.cpp index fd2eb85fd7bf..b544f0a39ea8 100644 --- a/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -1,9 +1,8 @@ //===- PlaceSafepoints.cpp - Place GC Safepoints --------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -56,7 +55,6 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LegacyPassManager.h" @@ -179,19 +177,18 @@ struct PlaceSafepoints : public FunctionPass { // callers job. static void InsertSafepointPoll(Instruction *InsertBefore, - std::vector<CallSite> &ParsePointsNeeded /*rval*/, + std::vector<CallBase *> &ParsePointsNeeded /*rval*/, const TargetLibraryInfo &TLI); -static bool needsStatepoint(const CallSite &CS, const TargetLibraryInfo &TLI) { - if (callsGCLeafFunction(CS, TLI)) +static bool needsStatepoint(CallBase *Call, const TargetLibraryInfo &TLI) { + if (callsGCLeafFunction(Call, TLI)) return false; - if (CS.isCall()) { - CallInst *call = cast<CallInst>(CS.getInstruction()); - if (call->isInlineAsm()) + if (auto *CI = dyn_cast<CallInst>(Call)) { + if (CI->isInlineAsm()) return false; } - return !(isStatepoint(CS) || isGCRelocate(CS) || isGCResult(CS)); + return !(isStatepoint(Call) || isGCRelocate(Call) || isGCResult(Call)); } /// Returns true if this loop is known to contain a call safepoint which @@ -217,14 +214,14 @@ static bool containsUnconditionalCallSafepoint(Loop *L, BasicBlock *Header, BasicBlock *Current = Pred; while (true) { for (Instruction &I : *Current) { - if (auto CS = CallSite(&I)) + if (auto *Call = dyn_cast<CallBase>(&I)) // Note: Technically, needing a safepoint isn't quite the right // condition here. We should instead be checking if the target method // has an // unconditional poll. In practice, this is only a theoretical concern // since we don't have any methods with conditional-only safepoint // polls. - if (needsStatepoint(CS, TLI)) + if (needsStatepoint(Call, TLI)) return true; } @@ -360,9 +357,8 @@ bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) { /// Returns true if an entry safepoint is not required before this callsite in /// the caller function. 
-static bool doesNotRequireEntrySafepointBefore(const CallSite &CS) { - Instruction *Inst = CS.getInstruction(); - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { +static bool doesNotRequireEntrySafepointBefore(CallBase *Call) { + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Call)) { switch (II->getIntrinsicID()) { case Intrinsic::experimental_gc_statepoint: case Intrinsic::experimental_patchpoint_void: @@ -424,8 +420,8 @@ static Instruction *findLocationForEntrySafepoint(Function &F, // which can grow the stack by an unbounded amount. This isn't required // for GC semantics per se, but is a common requirement for languages // which detect stack overflow via guard pages and then throw exceptions. - if (auto CS = CallSite(Cursor)) { - if (doesNotRequireEntrySafepointBefore(CS)) + if (auto *Call = dyn_cast<CallBase>(Cursor)) { + if (doesNotRequireEntrySafepointBefore(Call)) continue; break; } @@ -500,7 +496,7 @@ bool PlaceSafepoints::runOnFunction(Function &F) { DT.recalculate(F); SmallVector<Instruction *, 16> PollsNeeded; - std::vector<CallSite> ParsePointNeeded; + std::vector<CallBase *> ParsePointNeeded; if (enableBackedgeSafepoints(F)) { // Construct a pass manager to run the LoopPass backedge logic. We @@ -589,7 +585,7 @@ bool PlaceSafepoints::runOnFunction(Function &F) { // Now that we've identified all the needed safepoint poll locations, insert // safepoint polls themselves. for (Instruction *PollLocation : PollsNeeded) { - std::vector<CallSite> RuntimeCalls; + std::vector<CallBase *> RuntimeCalls; InsertSafepointPoll(PollLocation, RuntimeCalls, TLI); ParsePointNeeded.insert(ParsePointNeeded.end(), RuntimeCalls.begin(), RuntimeCalls.end()); @@ -622,7 +618,7 @@ INITIALIZE_PASS_END(PlaceSafepoints, "place-safepoints", "Place Safepoints", static void InsertSafepointPoll(Instruction *InsertBefore, - std::vector<CallSite> &ParsePointsNeeded /*rval*/, + std::vector<CallBase *> &ParsePointsNeeded /*rval*/, const TargetLibraryInfo &TLI) { BasicBlock *OrigBB = InsertBefore->getParent(); Module *M = InsertBefore->getModule(); @@ -687,7 +683,7 @@ InsertSafepointPoll(Instruction *InsertBefore, // These are likely runtime calls. Should we assert that via calling // convention or something? - ParsePointsNeeded.push_back(CallSite(CI)); + ParsePointsNeeded.push_back(CI); } assert(ParsePointsNeeded.size() <= Calls.size()); } diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index cb893eab1654..fa8c9e2a5fe4 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -1,9 +1,8 @@ //===- Reassociate.cpp - Reassociate binary expressions -------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -267,12 +266,16 @@ static BinaryOperator *CreateNeg(Value *S1, const Twine &Name, /// Replace 0-X with X*-1. static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) { + assert((isa<UnaryOperator>(Neg) || isa<BinaryOperator>(Neg)) && + "Expected a Negate!"); + // FIXME: It's not safe to lower a unary FNeg into a FMul by -1.0. + unsigned OpNo = isa<BinaryOperator>(Neg) ? 
1 : 0; Type *Ty = Neg->getType(); Constant *NegOne = Ty->isIntOrIntVectorTy() ? ConstantInt::getAllOnesValue(Ty) : ConstantFP::get(Ty, -1.0); - BinaryOperator *Res = CreateMul(Neg->getOperand(1), NegOne, "", Neg, Neg); - Neg->setOperand(1, Constant::getNullValue(Ty)); // Drop use of op. + BinaryOperator *Res = CreateMul(Neg->getOperand(OpNo), NegOne, "", Neg, Neg); + Neg->setOperand(OpNo, Constant::getNullValue(Ty)); // Drop use of op. Res->takeName(Neg); Neg->replaceAllUsesWith(Res); Res->setDebugLoc(Neg->getDebugLoc()); @@ -445,8 +448,10 @@ using RepeatedValue = std::pair<Value*, APInt>; /// that have all uses inside the expression (i.e. only used by non-leaf nodes /// of the expression) if it can turn them into binary operators of the right /// type and thus make the expression bigger. -static bool LinearizeExprTree(BinaryOperator *I, +static bool LinearizeExprTree(Instruction *I, SmallVectorImpl<RepeatedValue> &Ops) { + assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) && + "Expected a UnaryOperator or BinaryOperator!"); LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits(); unsigned Opcode = I->getOpcode(); @@ -463,7 +468,7 @@ static bool LinearizeExprTree(BinaryOperator *I, // with their weights, representing a certain number of paths to the operator. // If an operator occurs in the worklist multiple times then we found multiple // ways to get to it. - SmallVector<std::pair<BinaryOperator*, APInt>, 8> Worklist; // (Op, Weight) + SmallVector<std::pair<Instruction*, APInt>, 8> Worklist; // (Op, Weight) Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1))); bool Changed = false; @@ -490,10 +495,10 @@ static bool LinearizeExprTree(BinaryOperator *I, SmallPtrSet<Value *, 8> Visited; // For sanity checking the iteration scheme. #endif while (!Worklist.empty()) { - std::pair<BinaryOperator*, APInt> P = Worklist.pop_back_val(); + std::pair<Instruction*, APInt> P = Worklist.pop_back_val(); I = P.first; // We examine the operands of this binary operator. - for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) { // Visit operands. + for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands. Value *Op = I->getOperand(OpIdx); APInt Weight = P.second; // Number of paths to this operand. LLVM_DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n"); @@ -573,14 +578,14 @@ static bool LinearizeExprTree(BinaryOperator *I, // If this is a multiply expression, turn any internal negations into // multiplies by -1 so they can be reassociated. 
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) - if ((Opcode == Instruction::Mul && match(BO, m_Neg(m_Value()))) || - (Opcode == Instruction::FMul && match(BO, m_FNeg(m_Value())))) { + if (Instruction *Tmp = dyn_cast<Instruction>(Op)) + if ((Opcode == Instruction::Mul && match(Tmp, m_Neg(m_Value()))) || + (Opcode == Instruction::FMul && match(Tmp, m_FNeg(m_Value())))) { LLVM_DEBUG(dbgs() << "MORPH LEAF: " << *Op << " (" << Weight << ") TO "); - BO = LowerNegateToMultiply(BO); - LLVM_DEBUG(dbgs() << *BO << '\n'); - Worklist.push_back(std::make_pair(BO, Weight)); + Tmp = LowerNegateToMultiply(Tmp); + LLVM_DEBUG(dbgs() << *Tmp << '\n'); + Worklist.push_back(std::make_pair(Tmp, Weight)); Changed = true; continue; } @@ -862,6 +867,8 @@ static Value *NegateValue(Value *V, Instruction *BI, if (TheNeg->getParent()->getParent() != BI->getParent()->getParent()) continue; + bool FoundCatchSwitch = false; + BasicBlock::iterator InsertPt; if (Instruction *InstInput = dyn_cast<Instruction>(V)) { if (InvokeInst *II = dyn_cast<InvokeInst>(InstInput)) { @@ -869,10 +876,30 @@ static Value *NegateValue(Value *V, Instruction *BI, } else { InsertPt = ++InstInput->getIterator(); } - while (isa<PHINode>(InsertPt)) ++InsertPt; + + const BasicBlock *BB = InsertPt->getParent(); + + // Make sure we don't move anything before PHIs or exception + // handling pads. + while (InsertPt != BB->end() && (isa<PHINode>(InsertPt) || + InsertPt->isEHPad())) { + if (isa<CatchSwitchInst>(InsertPt)) + // A catchswitch cannot have anything in the block except + // itself and PHIs. We'll bail out below. + FoundCatchSwitch = true; + ++InsertPt; + } } else { InsertPt = TheNeg->getParent()->getParent()->getEntryBlock().begin(); } + + // We found a catchswitch in the block where we want to move the + // neg. We cannot move anything into that block. Bail and just + // create the neg before BI, as if we hadn't found an existing + // neg. + if (FoundCatchSwitch) + break; + TheNeg->moveBefore(&*InsertPt); if (TheNeg->getOpcode() == Instruction::Sub) { TheNeg->setHasNoUnsignedWrap(false); @@ -1329,8 +1356,7 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, // So, if Rank(X) < Rank(Y) < Rank(Z), it means X is defined earlier // than Y which is defined earlier than Z. Permute "x | 1", "Y & 2", // "z" in the order of X-Y-Z is better than any other orders. - std::stable_sort(OpndPtrs.begin(), OpndPtrs.end(), - [](XorOpnd *LHS, XorOpnd *RHS) { + llvm::stable_sort(OpndPtrs, [](XorOpnd *LHS, XorOpnd *RHS) { return LHS->getSymbolicRank() < RHS->getSymbolicRank(); }); @@ -1687,8 +1713,7 @@ static bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops, // below our mininum of '4'. assert(FactorPowerSum >= 4); - std::stable_sort(Factors.begin(), Factors.end(), - [](const Factor &LHS, const Factor &RHS) { + llvm::stable_sort(Factors, [](const Factor &LHS, const Factor &RHS) { return LHS.Power > RHS.Power; }); return true; @@ -1801,7 +1826,7 @@ Value *ReassociatePass::OptimizeMul(BinaryOperator *I, return V; ValueEntry NewEntry = ValueEntry(getRank(V), V); - Ops.insert(std::lower_bound(Ops.begin(), Ops.end(), NewEntry), NewEntry); + Ops.insert(llvm::lower_bound(Ops, NewEntry), NewEntry); return nullptr; } @@ -2001,7 +2026,7 @@ Instruction *ReassociatePass::canonicalizeNegConstExpr(Instruction *I) { /// instructions is not allowed. void ReassociatePass::OptimizeInst(Instruction *I) { // Only consider operations that we understand. 
- if (!isa<BinaryOperator>(I)) + if (!isa<UnaryOperator>(I) && !isa<BinaryOperator>(I)) return; if (I->getOpcode() == Instruction::Shl && isa<ConstantInt>(I->getOperand(1))) @@ -2066,7 +2091,8 @@ void ReassociatePass::OptimizeInst(Instruction *I) { I = NI; } } - } else if (I->getOpcode() == Instruction::FSub) { + } else if (I->getOpcode() == Instruction::FNeg || + I->getOpcode() == Instruction::FSub) { if (ShouldBreakUpSubtract(I)) { Instruction *NI = BreakUpSubtract(I, RedoInsts); RedoInsts.insert(I); @@ -2075,7 +2101,9 @@ void ReassociatePass::OptimizeInst(Instruction *I) { } else if (match(I, m_FNeg(m_Value()))) { // Otherwise, this is a negation. See if the operand is a multiply tree // and if this is not an inner node of a multiply tree. - if (isReassociableOp(I->getOperand(1), Instruction::FMul) && + Value *Op = isa<BinaryOperator>(I) ? I->getOperand(1) : + I->getOperand(0); + if (isReassociableOp(Op, Instruction::FMul) && (!I->hasOneUse() || !isReassociableOp(I->user_back(), Instruction::FMul))) { // If the negate was simplified, revisit the users to see if we can @@ -2142,7 +2170,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { // positions maintained (and so the compiler is deterministic). Note that // this sorts so that the highest ranking values end up at the beginning of // the vector. - std::stable_sort(Ops.begin(), Ops.end()); + llvm::stable_sort(Ops); // Now that we have the expression tree in a convenient // sorted form, optimize it globally if possible. @@ -2218,8 +2246,15 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { if (std::less<Value *>()(Op1, Op0)) std::swap(Op0, Op1); auto it = PairMap[Idx].find({Op0, Op1}); - if (it != PairMap[Idx].end()) - Score += it->second; + if (it != PairMap[Idx].end()) { + // Functions like BreakUpSubtract() can erase the Values we're using + // as keys and create new Values after we built the PairMap. There's a + // small chance that the new nodes can have the same address as + // something already in the table. We shouldn't accumulate the stored + // score in that case as it refers to the wrong Value. + if (it->second.isValid()) + Score += it->second.Score; + } unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank); if (Score > Max || (Score == Max && MaxRank < BestRank)) { @@ -2288,9 +2323,15 @@ ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) { std::swap(Op0, Op1); if (!Visited.insert({Op0, Op1}).second) continue; - auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, 1}); - if (!res.second) - ++res.first->second; + auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, {Op0, Op1, 1}}); + if (!res.second) { + // If either key value has been erased then we've got the same + // address by coincidence. That can't happen here because nothing is + // erasing values but it can happen by the time we're querying the + // map. + assert(res.first->second.isValid() && "WeakVH invalidated"); + ++res.first->second.Score; + } } } } diff --git a/lib/Transforms/Scalar/Reg2Mem.cpp b/lib/Transforms/Scalar/Reg2Mem.cpp index 018feb035a4f..3296322e00d5 100644 --- a/lib/Transforms/Scalar/Reg2Mem.cpp +++ b/lib/Transforms/Scalar/Reg2Mem.cpp @@ -1,9 +1,8 @@ //===- Reg2Mem.cpp - Convert registers to allocas -------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index 42d7ed5bc534..c358258d24cf 100644 --- a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -1,9 +1,8 @@ //===- RewriteStatepointsForGC.cpp - Make GC relocations explicit ---------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -26,18 +25,17 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -286,9 +284,9 @@ struct PartiallyConstructedSafepointRecord { } // end anonymous namespace -static ArrayRef<Use> GetDeoptBundleOperands(ImmutableCallSite CS) { +static ArrayRef<Use> GetDeoptBundleOperands(const CallBase *Call) { Optional<OperandBundleUse> DeoptBundle = - CS.getOperandBundle(LLVMContext::OB_deopt); + Call->getOperandBundle(LLVMContext::OB_deopt); if (!DeoptBundle.hasValue()) { assert(AllowStatepointWithNoDeoptInfo && @@ -370,14 +368,11 @@ static std::string suffixed_name_or(Value *V, StringRef Suffix, // given instruction. The analysis is performed immediately before the // given instruction. Values defined by that instruction are not considered // live. Values used by that instruction are considered live. -static void -analyzeParsePointLiveness(DominatorTree &DT, - GCPtrLivenessData &OriginalLivenessData, CallSite CS, - PartiallyConstructedSafepointRecord &Result) { - Instruction *Inst = CS.getInstruction(); - +static void analyzeParsePointLiveness( + DominatorTree &DT, GCPtrLivenessData &OriginalLivenessData, CallBase *Call, + PartiallyConstructedSafepointRecord &Result) { StatepointLiveSetTy LiveSet; - findLiveSetAtInst(Inst, OriginalLivenessData, LiveSet); + findLiveSetAtInst(Call, OriginalLivenessData, LiveSet); if (PrintLiveSet) { dbgs() << "Live Variables:\n"; @@ -385,7 +380,7 @@ analyzeParsePointLiveness(DominatorTree &DT, dbgs() << " " << V->getName() << " " << *V << "\n"; } if (PrintLiveSetSize) { - dbgs() << "Safepoint For: " << CS.getCalledValue()->getName() << "\n"; + dbgs() << "Safepoint For: " << Call->getCalledValue()->getName() << "\n"; dbgs() << "Number live values: " << LiveSet.size() << "\n"; } Result.LiveSet = LiveSet; @@ -1178,7 +1173,7 @@ findBasePointers(const StatepointLiveSetTy &live, /// Find the required based pointers (and adjust the live set) for the given /// parse point. 
static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache, - CallSite CS, + CallBase *Call, PartiallyConstructedSafepointRecord &result) { MapVector<Value *, Value *> PointerToBase; findBasePointers(result.LiveSet, PointerToBase, &DT, DVCache); @@ -1200,11 +1195,11 @@ static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache, /// Given an updated version of the dataflow liveness results, update the /// liveset and base pointer maps for the call site CS. static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData, - CallSite CS, + CallBase *Call, PartiallyConstructedSafepointRecord &result); static void recomputeLiveInValues( - Function &F, DominatorTree &DT, ArrayRef<CallSite> toUpdate, + Function &F, DominatorTree &DT, ArrayRef<CallBase *> toUpdate, MutableArrayRef<struct PartiallyConstructedSafepointRecord> records) { // TODO-PERF: reuse the original liveness, then simply run the dataflow // again. The old values are still live and will help it stabilize quickly. @@ -1307,7 +1302,7 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables, // Lazily populated map from input types to the canonicalized form mentioned // in the comment above. This should probably be cached somewhere more // broadly. - DenseMap<Type*, Value*> TypeToDeclMap; + DenseMap<Type *, Function *> TypeToDeclMap; for (unsigned i = 0; i < LiveVariables.size(); i++) { // Generate the gc.relocate call and save the result @@ -1318,7 +1313,7 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables, Type *Ty = LiveVariables[i]->getType(); if (!TypeToDeclMap.count(Ty)) TypeToDeclMap[Ty] = getGCRelocateDecl(Ty); - Value *GCRelocateDecl = TypeToDeclMap[Ty]; + Function *GCRelocateDecl = TypeToDeclMap[Ty]; // only specify a debug name if we can give a useful one CallInst *Reloc = Builder.CreateCall( @@ -1399,16 +1394,16 @@ public: } // end anonymous namespace -static StringRef getDeoptLowering(CallSite CS) { +static StringRef getDeoptLowering(CallBase *Call) { const char *DeoptLowering = "deopt-lowering"; - if (CS.hasFnAttr(DeoptLowering)) { - // FIXME: CallSite has a *really* confusing interface around attributes + if (Call->hasFnAttr(DeoptLowering)) { + // FIXME: Calls have a *really* confusing interface around attributes // with values. - const AttributeList &CSAS = CS.getAttributes(); + const AttributeList &CSAS = Call->getAttributes(); if (CSAS.hasAttribute(AttributeList::FunctionIndex, DeoptLowering)) return CSAS.getAttribute(AttributeList::FunctionIndex, DeoptLowering) .getValueAsString(); - Function *F = CS.getCalledFunction(); + Function *F = Call->getCalledFunction(); assert(F && F->hasFnAttribute(DeoptLowering)); return F->getFnAttribute(DeoptLowering).getValueAsString(); } @@ -1416,7 +1411,7 @@ static StringRef getDeoptLowering(CallSite CS) { } static void -makeStatepointExplicitImpl(const CallSite CS, /* to replace */ +makeStatepointExplicitImpl(CallBase *Call, /* to replace */ const SmallVectorImpl<Value *> &BasePtrs, const SmallVectorImpl<Value *> &LiveVariables, PartiallyConstructedSafepointRecord &Result, @@ -1427,19 +1422,18 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ // immediately before the previous instruction under the assumption that all // arguments will be available here. We can't insert afterwards since we may // be replacing a terminator. 
- Instruction *InsertBefore = CS.getInstruction(); - IRBuilder<> Builder(InsertBefore); + IRBuilder<> Builder(Call); ArrayRef<Value *> GCArgs(LiveVariables); uint64_t StatepointID = StatepointDirectives::DefaultStatepointID; uint32_t NumPatchBytes = 0; uint32_t Flags = uint32_t(StatepointFlags::None); - ArrayRef<Use> CallArgs(CS.arg_begin(), CS.arg_end()); - ArrayRef<Use> DeoptArgs = GetDeoptBundleOperands(CS); + ArrayRef<Use> CallArgs(Call->arg_begin(), Call->arg_end()); + ArrayRef<Use> DeoptArgs = GetDeoptBundleOperands(Call); ArrayRef<Use> TransitionArgs; if (auto TransitionBundle = - CS.getOperandBundle(LLVMContext::OB_gc_transition)) { + Call->getOperandBundle(LLVMContext::OB_gc_transition)) { Flags |= uint32_t(StatepointFlags::GCTransition); TransitionArgs = TransitionBundle->Inputs; } @@ -1450,21 +1444,21 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ bool IsDeoptimize = false; StatepointDirectives SD = - parseStatepointDirectivesFromAttrs(CS.getAttributes()); + parseStatepointDirectivesFromAttrs(Call->getAttributes()); if (SD.NumPatchBytes) NumPatchBytes = *SD.NumPatchBytes; if (SD.StatepointID) StatepointID = *SD.StatepointID; // Pass through the requested lowering if any. The default is live-through. - StringRef DeoptLowering = getDeoptLowering(CS); + StringRef DeoptLowering = getDeoptLowering(Call); if (DeoptLowering.equals("live-in")) Flags |= uint32_t(StatepointFlags::DeoptLiveIn); else { assert(DeoptLowering.equals("live-through") && "Unsupported value!"); } - Value *CallTarget = CS.getCalledValue(); + Value *CallTarget = Call->getCalledValue(); if (Function *F = dyn_cast<Function>(CallTarget)) { if (F->getIntrinsicID() == Intrinsic::experimental_deoptimize) { // Calls to llvm.experimental.deoptimize are lowered to calls to the @@ -1481,8 +1475,9 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ // calls to @llvm.experimental.deoptimize with different argument types in // the same module. This is fine -- we assume the frontend knew what it // was doing when generating this kind of IR. - CallTarget = - F->getParent()->getOrInsertFunction("__llvm_deoptimize", FTy); + CallTarget = F->getParent() + ->getOrInsertFunction("__llvm_deoptimize", FTy) + .getCallee(); IsDeoptimize = true; } @@ -1490,57 +1485,56 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ // Create the statepoint given all the arguments Instruction *Token = nullptr; - if (CS.isCall()) { - CallInst *ToReplace = cast<CallInst>(CS.getInstruction()); - CallInst *Call = Builder.CreateGCStatepointCall( + if (auto *CI = dyn_cast<CallInst>(Call)) { + CallInst *SPCall = Builder.CreateGCStatepointCall( StatepointID, NumPatchBytes, CallTarget, Flags, CallArgs, TransitionArgs, DeoptArgs, GCArgs, "safepoint_token"); - Call->setTailCallKind(ToReplace->getTailCallKind()); - Call->setCallingConv(ToReplace->getCallingConv()); + SPCall->setTailCallKind(CI->getTailCallKind()); + SPCall->setCallingConv(CI->getCallingConv()); // Currently we will fail on parameter attributes and on certain // function attributes. In case if we can handle this set of attributes - // set up function attrs directly on statepoint and return attrs later for // gc_result intrinsic. 
- Call->setAttributes(legalizeCallAttributes(ToReplace->getAttributes())); + SPCall->setAttributes(legalizeCallAttributes(CI->getAttributes())); - Token = Call; + Token = SPCall; // Put the following gc_result and gc_relocate calls immediately after the // the old call (which we're about to delete) - assert(ToReplace->getNextNode() && "Not a terminator, must have next!"); - Builder.SetInsertPoint(ToReplace->getNextNode()); - Builder.SetCurrentDebugLocation(ToReplace->getNextNode()->getDebugLoc()); + assert(CI->getNextNode() && "Not a terminator, must have next!"); + Builder.SetInsertPoint(CI->getNextNode()); + Builder.SetCurrentDebugLocation(CI->getNextNode()->getDebugLoc()); } else { - InvokeInst *ToReplace = cast<InvokeInst>(CS.getInstruction()); + auto *II = cast<InvokeInst>(Call); // Insert the new invoke into the old block. We'll remove the old one in a // moment at which point this will become the new terminator for the // original block. - InvokeInst *Invoke = Builder.CreateGCStatepointInvoke( - StatepointID, NumPatchBytes, CallTarget, ToReplace->getNormalDest(), - ToReplace->getUnwindDest(), Flags, CallArgs, TransitionArgs, DeoptArgs, - GCArgs, "statepoint_token"); + InvokeInst *SPInvoke = Builder.CreateGCStatepointInvoke( + StatepointID, NumPatchBytes, CallTarget, II->getNormalDest(), + II->getUnwindDest(), Flags, CallArgs, TransitionArgs, DeoptArgs, GCArgs, + "statepoint_token"); - Invoke->setCallingConv(ToReplace->getCallingConv()); + SPInvoke->setCallingConv(II->getCallingConv()); // Currently we will fail on parameter attributes and on certain // function attributes. In case if we can handle this set of attributes - // set up function attrs directly on statepoint and return attrs later for // gc_result intrinsic. - Invoke->setAttributes(legalizeCallAttributes(ToReplace->getAttributes())); + SPInvoke->setAttributes(legalizeCallAttributes(II->getAttributes())); - Token = Invoke; + Token = SPInvoke; // Generate gc relocates in exceptional path - BasicBlock *UnwindBlock = ToReplace->getUnwindDest(); + BasicBlock *UnwindBlock = II->getUnwindDest(); assert(!isa<PHINode>(UnwindBlock->begin()) && UnwindBlock->getUniquePredecessor() && "can't safely insert in this block!"); Builder.SetInsertPoint(&*UnwindBlock->getFirstInsertionPt()); - Builder.SetCurrentDebugLocation(ToReplace->getDebugLoc()); + Builder.SetCurrentDebugLocation(II->getDebugLoc()); // Attach exceptional gc relocates to the landingpad. Instruction *ExceptionalToken = UnwindBlock->getLandingPadInst(); @@ -1551,7 +1545,7 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ Builder); // Generate gc relocates and returns for normal block - BasicBlock *NormalDest = ToReplace->getNormalDest(); + BasicBlock *NormalDest = II->getNormalDest(); assert(!isa<PHINode>(NormalDest->begin()) && NormalDest->getUniquePredecessor() && "can't safely insert in this block!"); @@ -1568,16 +1562,15 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ // transform the tail-call like structure to a call to a void function // followed by unreachable to get better codegen. Replacements.push_back( - DeferredReplacement::createDeoptimizeReplacement(CS.getInstruction())); + DeferredReplacement::createDeoptimizeReplacement(Call)); } else { Token->setName("statepoint_token"); - if (!CS.getType()->isVoidTy() && !CS.getInstruction()->use_empty()) { - StringRef Name = - CS.getInstruction()->hasName() ? 
CS.getInstruction()->getName() : ""; - CallInst *GCResult = Builder.CreateGCResult(Token, CS.getType(), Name); + if (!Call->getType()->isVoidTy() && !Call->use_empty()) { + StringRef Name = Call->hasName() ? Call->getName() : ""; + CallInst *GCResult = Builder.CreateGCResult(Token, Call->getType(), Name); GCResult->setAttributes( AttributeList::get(GCResult->getContext(), AttributeList::ReturnIndex, - CS.getAttributes().getRetAttributes())); + Call->getAttributes().getRetAttributes())); // We cannot RAUW or delete CS.getInstruction() because it could be in the // live set of some other safepoint, in which case that safepoint's @@ -1586,10 +1579,9 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ // after the live sets have been made explicit in the IR, and we no longer // have raw pointers to worry about. Replacements.emplace_back( - DeferredReplacement::createRAUW(CS.getInstruction(), GCResult)); + DeferredReplacement::createRAUW(Call, GCResult)); } else { - Replacements.emplace_back( - DeferredReplacement::createDelete(CS.getInstruction())); + Replacements.emplace_back(DeferredReplacement::createDelete(Call)); } } @@ -1606,7 +1598,7 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ // WARNING: Does not do any fixup to adjust users of the original live // values. That's the callers responsibility. static void -makeStatepointExplicit(DominatorTree &DT, CallSite CS, +makeStatepointExplicit(DominatorTree &DT, CallBase *Call, PartiallyConstructedSafepointRecord &Result, std::vector<DeferredReplacement> &Replacements) { const auto &LiveSet = Result.LiveSet; @@ -1625,7 +1617,7 @@ makeStatepointExplicit(DominatorTree &DT, CallSite CS, assert(LiveVec.size() == BaseVec.size()); // Do the actual rewriting and delete the old statepoint - makeStatepointExplicitImpl(CS, BaseVec, LiveVec, Result, Replacements); + makeStatepointExplicitImpl(Call, BaseVec, LiveVec, Result, Replacements); } // Helper function for the relocationViaAlloca. @@ -1636,7 +1628,7 @@ makeStatepointExplicit(DominatorTree &DT, CallSite CS, // for sanity checking. static void insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs, - DenseMap<Value *, Value *> &AllocaMap, + DenseMap<Value *, AllocaInst *> &AllocaMap, DenseSet<Value *> &VisitedLiveValues) { for (User *U : GCRelocs) { GCRelocateInst *Relocate = dyn_cast<GCRelocateInst>(U); @@ -1671,7 +1663,7 @@ insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs, // "insertRelocationStores" but works for rematerialized values. 
static void insertRematerializationStores( const RematerializedValueMapTy &RematerializedValues, - DenseMap<Value *, Value *> &AllocaMap, + DenseMap<Value *, AllocaInst *> &AllocaMap, DenseSet<Value *> &VisitedLiveValues) { for (auto RematerializedValuePair: RematerializedValues) { Instruction *RematerializedValue = RematerializedValuePair.first; @@ -1704,7 +1696,7 @@ static void relocationViaAlloca( #endif // TODO-PERF: change data structures, reserve - DenseMap<Value *, Value *> AllocaMap; + DenseMap<Value *, AllocaInst *> AllocaMap; SmallVector<AllocaInst *, 200> PromotableAllocas; // Used later to chack that we have enough allocas to store all values std::size_t NumRematerializedValues = 0; @@ -1774,7 +1766,7 @@ static void relocationViaAlloca( SmallVector<AllocaInst *, 64> ToClobber; for (auto Pair : AllocaMap) { Value *Def = Pair.first; - AllocaInst *Alloca = cast<AllocaInst>(Pair.second); + AllocaInst *Alloca = Pair.second; // This value was relocated if (VisitedLiveValues.count(Def)) { @@ -1806,7 +1798,7 @@ static void relocationViaAlloca( // Update use with load allocas and add store for gc_relocated. for (auto Pair : AllocaMap) { Value *Def = Pair.first; - Value *Alloca = Pair.second; + AllocaInst *Alloca = Pair.second; // We pre-record the uses of allocas so that we dont have to worry about // later update that changes the user information.. @@ -1834,13 +1826,15 @@ static void relocationViaAlloca( PHINode *Phi = cast<PHINode>(Use); for (unsigned i = 0; i < Phi->getNumIncomingValues(); i++) { if (Def == Phi->getIncomingValue(i)) { - LoadInst *Load = new LoadInst( - Alloca, "", Phi->getIncomingBlock(i)->getTerminator()); + LoadInst *Load = + new LoadInst(Alloca->getAllocatedType(), Alloca, "", + Phi->getIncomingBlock(i)->getTerminator()); Phi->setIncomingValue(i, Load); } } } else { - LoadInst *Load = new LoadInst(Alloca, "", Use); + LoadInst *Load = + new LoadInst(Alloca->getAllocatedType(), Alloca, "", Use); Use->replaceUsesOfWith(Def, Load); } } @@ -1893,25 +1887,25 @@ template <typename T> static void unique_unsorted(SmallVectorImpl<T> &Vec) { /// Insert holders so that each Value is obviously live through the entire /// lifetime of the call. 
-static void insertUseHolderAfter(CallSite &CS, const ArrayRef<Value *> Values, +static void insertUseHolderAfter(CallBase *Call, const ArrayRef<Value *> Values, SmallVectorImpl<CallInst *> &Holders) { if (Values.empty()) // No values to hold live, might as well not insert the empty holder return; - Module *M = CS.getInstruction()->getModule(); + Module *M = Call->getModule(); // Use a dummy vararg function to actually hold the values live - Function *Func = cast<Function>(M->getOrInsertFunction( - "__tmp_use", FunctionType::get(Type::getVoidTy(M->getContext()), true))); - if (CS.isCall()) { + FunctionCallee Func = M->getOrInsertFunction( + "__tmp_use", FunctionType::get(Type::getVoidTy(M->getContext()), true)); + if (isa<CallInst>(Call)) { // For call safepoints insert dummy calls right after safepoint - Holders.push_back(CallInst::Create(Func, Values, "", - &*++CS.getInstruction()->getIterator())); + Holders.push_back( + CallInst::Create(Func, Values, "", &*++Call->getIterator())); return; } // For invoke safepooints insert dummy calls both in normal and // exceptional destination blocks - auto *II = cast<InvokeInst>(CS.getInstruction()); + auto *II = cast<InvokeInst>(Call); Holders.push_back(CallInst::Create( Func, Values, "", &*II->getNormalDest()->getFirstInsertionPt())); Holders.push_back(CallInst::Create( @@ -1919,7 +1913,7 @@ static void insertUseHolderAfter(CallSite &CS, const ArrayRef<Value *> Values, } static void findLiveReferences( - Function &F, DominatorTree &DT, ArrayRef<CallSite> toUpdate, + Function &F, DominatorTree &DT, ArrayRef<CallBase *> toUpdate, MutableArrayRef<struct PartiallyConstructedSafepointRecord> records) { GCPtrLivenessData OriginalLivenessData; computeLiveInValues(DT, F, OriginalLivenessData); @@ -2022,7 +2016,7 @@ static bool AreEquivalentPhiNodes(PHINode &OrigRootPhi, PHINode &AlternateRootPh // to relocate. Remove this values from the live set, rematerialize them after // statepoint and record them in "Info" structure. Note that similar to // relocated values we don't do any user adjustments here. -static void rematerializeLiveValues(CallSite CS, +static void rematerializeLiveValues(CallBase *Call, PartiallyConstructedSafepointRecord &Info, TargetTransformInfo &TTI) { const unsigned int ChainLengthThreshold = 10; @@ -2076,7 +2070,7 @@ static void rematerializeLiveValues(CallSite CS, // For invokes we need to rematerialize each chain twice - for normal and // for unwind basic blocks. Model this by multiplying cost by two. - if (CS.isInvoke()) { + if (isa<InvokeInst>(Call)) { Cost *= 2; } // If it's too expensive - skip it @@ -2144,14 +2138,14 @@ static void rematerializeLiveValues(CallSite CS, // Different cases for calls and invokes. For invokes we need to clone // instructions both on normal and unwind path. 
- if (CS.isCall()) { - Instruction *InsertBefore = CS.getInstruction()->getNextNode(); + if (isa<CallInst>(Call)) { + Instruction *InsertBefore = Call->getNextNode(); assert(InsertBefore); Instruction *RematerializedValue = rematerializeChain( InsertBefore, RootOfChain, Info.PointerToBase[LiveValue]); Info.RematerializedValues[RematerializedValue] = LiveValue; } else { - InvokeInst *Invoke = cast<InvokeInst>(CS.getInstruction()); + auto *Invoke = cast<InvokeInst>(Call); Instruction *NormalInsertBefore = &*Invoke->getNormalDest()->getFirstInsertionPt(); @@ -2176,25 +2170,25 @@ static void rematerializeLiveValues(CallSite CS, static bool insertParsePoints(Function &F, DominatorTree &DT, TargetTransformInfo &TTI, - SmallVectorImpl<CallSite> &ToUpdate) { + SmallVectorImpl<CallBase *> &ToUpdate) { #ifndef NDEBUG // sanity check the input - std::set<CallSite> Uniqued; + std::set<CallBase *> Uniqued; Uniqued.insert(ToUpdate.begin(), ToUpdate.end()); assert(Uniqued.size() == ToUpdate.size() && "no duplicates please!"); - for (CallSite CS : ToUpdate) - assert(CS.getInstruction()->getFunction() == &F); + for (CallBase *Call : ToUpdate) + assert(Call->getFunction() == &F); #endif // When inserting gc.relocates for invokes, we need to be able to insert at // the top of the successor blocks. See the comment on // normalForInvokeSafepoint on exactly what is needed. Note that this step // may restructure the CFG. - for (CallSite CS : ToUpdate) { - if (!CS.isInvoke()) + for (CallBase *Call : ToUpdate) { + auto *II = dyn_cast<InvokeInst>(Call); + if (!II) continue; - auto *II = cast<InvokeInst>(CS.getInstruction()); normalizeForInvokeSafepoint(II->getNormalDest(), II->getParent(), DT); normalizeForInvokeSafepoint(II->getUnwindDest(), II->getParent(), DT); } @@ -2207,17 +2201,17 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // actual safepoint insertion as arguments. This ensures reference operands // in the deopt argument list are considered live through the safepoint (and // thus makes sure they get relocated.) - for (CallSite CS : ToUpdate) { + for (CallBase *Call : ToUpdate) { SmallVector<Value *, 64> DeoptValues; - for (Value *Arg : GetDeoptBundleOperands(CS)) { + for (Value *Arg : GetDeoptBundleOperands(Call)) { assert(!isUnhandledGCPointerType(Arg->getType()) && "support for FCA unimplemented"); if (isHandledGCPointerType(Arg->getType())) DeoptValues.push_back(Arg); } - insertUseHolderAfter(CS, DeoptValues, Holders); + insertUseHolderAfter(Call, DeoptValues, Holders); } SmallVector<PartiallyConstructedSafepointRecord, 64> Records(ToUpdate.size()); @@ -2319,7 +2313,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, for (size_t i = 0; i < Records.size(); i++) makeStatepointExplicit(DT, ToUpdate[i], Records[i], Replacements); - ToUpdate.clear(); // prevent accident use of invalid CallSites + ToUpdate.clear(); // prevent accident use of invalid calls. for (auto &PR : Replacements) PR.doReplacement(); @@ -2384,7 +2378,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, return !Records.empty(); } -// Handles both return values and arguments for Functions and CallSites. +// Handles both return values and arguments for Functions and calls. 
template <typename AttrHolder> static void RemoveNonValidAttrAtIndex(LLVMContext &Ctx, AttrHolder &AH, unsigned Index) { @@ -2476,12 +2470,13 @@ static void stripNonValidDataFromBody(Function &F) { stripInvalidMetadataFromInstruction(I); - if (CallSite CS = CallSite(&I)) { - for (int i = 0, e = CS.arg_size(); i != e; i++) - if (isa<PointerType>(CS.getArgument(i)->getType())) - RemoveNonValidAttrAtIndex(Ctx, CS, i + AttributeList::FirstArgIndex); - if (isa<PointerType>(CS.getType())) - RemoveNonValidAttrAtIndex(Ctx, CS, AttributeList::ReturnIndex); + if (auto *Call = dyn_cast<CallBase>(&I)) { + for (int i = 0, e = Call->arg_size(); i != e; i++) + if (isa<PointerType>(Call->getArgOperand(i)->getType())) + RemoveNonValidAttrAtIndex(Ctx, *Call, + i + AttributeList::FirstArgIndex); + if (isa<PointerType>(Call->getType())) + RemoveNonValidAttrAtIndex(Ctx, *Call, AttributeList::ReturnIndex); } } @@ -2526,12 +2521,11 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, assert(shouldRewriteStatepointsIn(F) && "mismatch in rewrite decision"); auto NeedsRewrite = [&TLI](Instruction &I) { - if (ImmutableCallSite CS = ImmutableCallSite(&I)) - return !callsGCLeafFunction(CS, TLI) && !isStatepoint(CS); + if (const auto *Call = dyn_cast<CallBase>(&I)) + return !callsGCLeafFunction(Call, TLI) && !isStatepoint(Call); return false; }; - // Delete any unreachable statepoints so that we don't have unrewritten // statepoints surviving this pass. This makes testing easier and the // resulting IR less confusing to human readers. @@ -2543,7 +2537,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, // Gather all the statepoints which need rewritten. Be careful to only // consider those in reachable code since we need to ask dominance queries // when rewriting. We'll delete the unreachable ones in a moment. - SmallVector<CallSite, 64> ParsePointNeeded; + SmallVector<CallBase *, 64> ParsePointNeeded; for (Instruction &I : instructions(F)) { // TODO: only the ones with the flag set! if (NeedsRewrite(I)) { @@ -2553,7 +2547,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, // isReachableFromEntry() returns true. assert(DT.isReachableFromEntry(I.getParent()) && "no unreachable blocks expected"); - ParsePointNeeded.push_back(CallSite(&I)); + ParsePointNeeded.push_back(cast<CallBase>(&I)); } } @@ -2602,6 +2596,33 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, } } + // Nasty workaround - The base computation code in the main algorithm doesn't + // consider the fact that a GEP can be used to convert a scalar to a vector. + // The right fix for this is to integrate GEPs into the base rewriting + // algorithm properly, this is just a short term workaround to prevent + // crashes by canonicalizing such GEPs into fully vector GEPs. + for (Instruction &I : instructions(F)) { + if (!isa<GetElementPtrInst>(I)) + continue; + + unsigned VF = 0; + for (unsigned i = 0; i < I.getNumOperands(); i++) + if (I.getOperand(i)->getType()->isVectorTy()) { + assert(VF == 0 || + VF == I.getOperand(i)->getType()->getVectorNumElements()); + VF = I.getOperand(i)->getType()->getVectorNumElements(); + } + + // It's the vector to scalar traversal through the pointer operand which + // confuses base pointer rewriting, so limit ourselves to that case. 
+ if (!I.getOperand(0)->getType()->isVectorTy() && VF != 0) { + IRBuilder<> B(&I); + auto *Splat = B.CreateVectorSplat(VF, I.getOperand(0)); + I.setOperand(0, Splat); + MadeChange = true; + } + } + MadeChange |= insertParsePoints(F, DT, TTI, ParsePointNeeded); return MadeChange; } @@ -2786,11 +2807,10 @@ static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data, } static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData, - CallSite CS, + CallBase *Call, PartiallyConstructedSafepointRecord &Info) { - Instruction *Inst = CS.getInstruction(); StatepointLiveSetTy Updated; - findLiveSetAtInst(Inst, RevisedLivenessData, Updated); + findLiveSetAtInst(Call, RevisedLivenessData, Updated); // We may have base pointers which are now live that weren't before. We need // to update the PointerToBase structure to reflect this. diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp index 2f6ed05c023b..4093e50ce899 100644 --- a/lib/Transforms/Scalar/SCCP.cpp +++ b/lib/Transforms/Scalar/SCCP.cpp @@ -1,9 +1,8 @@ //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -21,6 +20,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -210,11 +210,11 @@ class SCCPSolver : public InstVisitor<SCCPSolver> { /// TrackedRetVals - If we are tracking arguments into and the return /// value out of a function, it will have an entry in this map, indicating /// what the known return value for the function is. - DenseMap<Function *, LatticeVal> TrackedRetVals; + MapVector<Function *, LatticeVal> TrackedRetVals; /// TrackedMultipleRetVals - Same as TrackedRetVals, but used for functions /// that return multiple values. - DenseMap<std::pair<Function *, unsigned>, LatticeVal> TrackedMultipleRetVals; + MapVector<std::pair<Function *, unsigned>, LatticeVal> TrackedMultipleRetVals; /// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is /// represented here for efficient lookup. @@ -372,7 +372,7 @@ public: } /// getTrackedRetVals - Get the inferred return value map. 
- const DenseMap<Function*, LatticeVal> &getTrackedRetVals() { + const MapVector<Function*, LatticeVal> &getTrackedRetVals() { return TrackedRetVals; } @@ -614,6 +614,7 @@ private: void visitCastInst(CastInst &I); void visitSelectInst(SelectInst &I); + void visitUnaryOperator(Instruction &I); void visitBinaryOperator(Instruction &I); void visitCmpInst(CmpInst &I); void visitExtractValueInst(ExtractValueInst &EVI); @@ -639,6 +640,11 @@ private: visitTerminator(II); } + void visitCallBrInst (CallBrInst &CBI) { + visitCallSite(&CBI); + visitTerminator(CBI); + } + void visitCallSite (CallSite CS); void visitResumeInst (ResumeInst &I) { /*returns void*/ } void visitUnreachableInst(UnreachableInst &I) { /*returns void*/ } @@ -734,6 +740,13 @@ void SCCPSolver::getFeasibleSuccessors(Instruction &TI, return; } + // In case of callbr, we pessimistically assume that all successors are + // feasible. + if (isa<CallBrInst>(&TI)) { + Succs.assign(TI.getNumSuccessors(), true); + return; + } + LLVM_DEBUG(dbgs() << "Unknown terminator instruction: " << TI << '\n'); llvm_unreachable("SCCP: Don't know how to handle this terminator!"); } @@ -825,7 +838,7 @@ void SCCPSolver::visitReturnInst(ReturnInst &I) { // If we are tracking the return value of this function, merge it in. if (!TrackedRetVals.empty() && !ResultOp->getType()->isStructTy()) { - DenseMap<Function*, LatticeVal>::iterator TFRVI = + MapVector<Function*, LatticeVal>::iterator TFRVI = TrackedRetVals.find(F); if (TFRVI != TrackedRetVals.end()) { mergeInValue(TFRVI->second, F, getValueState(ResultOp)); @@ -958,6 +971,29 @@ void SCCPSolver::visitSelectInst(SelectInst &I) { markOverdefined(&I); } +// Handle Unary Operators. +void SCCPSolver::visitUnaryOperator(Instruction &I) { + LatticeVal V0State = getValueState(I.getOperand(0)); + + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + if (V0State.isConstant()) { + Constant *C = ConstantExpr::get(I.getOpcode(), V0State.getConstant()); + + // op Y -> undef. + if (isa<UndefValue>(C)) + return; + return (void)markConstant(IV, &I, C); + } + + // If something is undef, wait for it to resolve. + if (!V0State.isOverdefined()) + return; + + markOverdefined(&I); +} + // Handle Binary Operators. void SCCPSolver::visitBinaryOperator(Instruction &I) { LatticeVal V1State = getValueState(I.getOperand(0)); @@ -1232,7 +1268,7 @@ CallOverdefined: // Otherwise, if we have a single return value case, and if the function is // a declaration, maybe we can constant fold it. if (F && F->isDeclaration() && !I->getType()->isStructTy() && - canConstantFoldCallTo(CS, F)) { + canConstantFoldCallTo(cast<CallBase>(CS.getInstruction()), F)) { SmallVector<Constant*, 8> Operands; for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); AI != E; ++AI) { @@ -1253,7 +1289,8 @@ CallOverdefined: // If we can constant fold this, mark the result of the call as a // constant. - if (Constant *C = ConstantFoldCall(CS, F, Operands, TLI)) { + if (Constant *C = ConstantFoldCall(cast<CallBase>(CS.getInstruction()), F, + Operands, TLI)) { // call -> undef. if (isa<UndefValue>(C)) return; @@ -1315,7 +1352,7 @@ CallOverdefined: mergeInValue(getStructValueState(I, i), I, TrackedMultipleRetVals[std::make_pair(F, i)]); } else { - DenseMap<Function*, LatticeVal>::iterator TFRVI = TrackedRetVals.find(F); + MapVector<Function*, LatticeVal>::iterator TFRVI = TrackedRetVals.find(F); if (TFRVI == TrackedRetVals.end()) goto CallOverdefined; // Not tracking this callee. 
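The SCCP changes above swap DenseMap for MapVector in TrackedRetVals and TrackedMultipleRetVals. The practical difference is iteration order: MapVector walks entries in insertion order, while DenseMap's order depends on pointer hashing, so anything derived from a walk over these maps (such as the returns collected for zapping in runIPSCCP) becomes deterministic. A minimal sketch of that property, assuming a plain int as a stand-in for the pass-internal LatticeVal:

#include "llvm/ADT/MapVector.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"

using namespace llvm;

// Hypothetical helper, not part of the patch: record one entry per function
// and then walk the map. MapVector revisits the functions in exactly the
// order they were inserted, which is what keeps downstream output stable.
static void walkTrackedRetVals(Module &M) {
  MapVector<Function *, int> TrackedRetVals; // int stands in for LatticeVal.
  int NextId = 0;
  for (Function &F : M)
    TrackedRetVals.insert({&F, NextId++});
  for (const auto &KV : TrackedRetVals)
    (void)KV; // Deterministic: same order as the module's function list.
}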
@@ -1472,6 +1509,8 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { else markOverdefined(&I); return true; + case Instruction::FNeg: + break; // fneg undef -> undef case Instruction::ZExt: case Instruction::SExt: case Instruction::FPToUI: @@ -1598,6 +1637,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { return true; case Instruction::Call: case Instruction::Invoke: + case Instruction::CallBr: // There are two reasons a call can have an undef result // 1. It could be tracked. // 2. It could be constant-foldable. @@ -2070,12 +2110,22 @@ bool llvm::runIPSCCP( // If we have forced an edge for an indeterminate value, then force the // terminator to fold to that edge. forceIndeterminateEdge(I, Solver); - bool Folded = ConstantFoldTerminator(I->getParent(), + BasicBlock *InstBB = I->getParent(); + bool Folded = ConstantFoldTerminator(InstBB, /*DeleteDeadConditions=*/false, /*TLI=*/nullptr, &DTU); assert(Folded && "Expect TermInst on constantint or blockaddress to be folded"); (void) Folded; + // If we folded the terminator to an unconditional branch to another + // dead block, replace it with Unreachable, to avoid trying to fold that + // branch again. + BranchInst *BI = cast<BranchInst>(InstBB->getTerminator()); + if (BI && BI->isUnconditional() && + !Solver.isBlockExecutable(BI->getSuccessor(0))) { + InstBB->getTerminator()->eraseFromParent(); + new UnreachableInst(InstBB->getContext(), InstBB); + } } // Mark dead BB for deletion. DTU.deleteBB(DeadBB); @@ -2109,7 +2159,7 @@ bool llvm::runIPSCCP( // whether other functions are optimizable. SmallVector<ReturnInst*, 8> ReturnsToZap; - const DenseMap<Function*, LatticeVal> &RV = Solver.getTrackedRetVals(); + const MapVector<Function*, LatticeVal> &RV = Solver.getTrackedRetVals(); for (const auto &I : RV) { Function *F = I.first; if (I.second.isOverdefined() || F->getReturnType()->isVoidTy()) diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp index eab77cf4cda9..33f90d0b01e4 100644 --- a/lib/Transforms/Scalar/SROA.cpp +++ b/lib/Transforms/Scalar/SROA.cpp @@ -1,9 +1,8 @@ //===- SROA.cpp - Scalar Replacement Of Aggregates ------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// \file @@ -222,13 +221,6 @@ public: } // end anonymous namespace -namespace llvm { - -template <typename T> struct isPodLike; -template <> struct isPodLike<Slice> { static const bool value = true; }; - -} // end namespace llvm - /// Representation of the alloca slices. 
/// /// This class represents the slices of an alloca which are formed by its @@ -721,6 +713,13 @@ private: return Base::visitBitCastInst(BC); } + void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) { + if (ASC.use_empty()) + return markAsDead(ASC); + + return Base::visitAddrSpaceCastInst(ASC); + } + void visitGetElementPtrInst(GetElementPtrInst &GEPI) { if (GEPI.use_empty()) return markAsDead(GEPI); @@ -784,7 +783,10 @@ private: if (!IsOffsetKnown) return PI.setAborted(&LI); - const DataLayout &DL = LI.getModule()->getDataLayout(); + if (LI.isVolatile() && + LI.getPointerAddressSpace() != DL.getAllocaAddrSpace()) + return PI.setAborted(&LI); + uint64_t Size = DL.getTypeStoreSize(LI.getType()); return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile()); } @@ -796,7 +798,10 @@ private: if (!IsOffsetKnown) return PI.setAborted(&SI); - const DataLayout &DL = SI.getModule()->getDataLayout(); + if (SI.isVolatile() && + SI.getPointerAddressSpace() != DL.getAllocaAddrSpace()) + return PI.setAborted(&SI); + uint64_t Size = DL.getTypeStoreSize(ValOp->getType()); // If this memory access can be shown to *statically* extend outside the @@ -831,6 +836,11 @@ private: if (!IsOffsetKnown) return PI.setAborted(&II); + // Don't replace this with a store with a different address space. TODO: + // Use a store with the casted new alloca? + if (II.isVolatile() && II.getDestAddressSpace() != DL.getAllocaAddrSpace()) + return PI.setAborted(&II); + insertUse(II, Offset, Length ? Length->getLimitedValue() : AllocSize - Offset.getLimitedValue(), (bool)Length); @@ -850,6 +860,13 @@ private: if (!IsOffsetKnown) return PI.setAborted(&II); + // Don't replace this with a load/store with a different address space. + // TODO: Use a store with the casted new alloca? + if (II.isVolatile() && + (II.getDestAddressSpace() != DL.getAllocaAddrSpace() || + II.getSourceAddressSpace() != DL.getAllocaAddrSpace())) + return PI.setAborted(&II); + // This side of the transfer is completely out-of-bounds, and so we can // nuke the entire transfer. However, we also need to nuke the other side // if already added to our partitions. @@ -957,7 +974,7 @@ private: if (!GEP->hasAllZeroIndices()) return GEP; } else if (!isa<BitCastInst>(I) && !isa<PHINode>(I) && - !isa<SelectInst>(I)) { + !isa<SelectInst>(I) && !isa<AddrSpaceCastInst>(I)) { return I; } @@ -1173,12 +1190,16 @@ static Type *findCommonType(AllocaSlices::const_iterator B, /// FIXME: This should be hoisted into a generic utility, likely in /// Transforms/Util/Local.h static bool isSafePHIToSpeculate(PHINode &PN) { + const DataLayout &DL = PN.getModule()->getDataLayout(); + // For now, we can only do this promotion if the load is in the same block // as the PHI, and if there are no stores between the phi and load. // TODO: Allow recursive phi users. // TODO: Allow stores. BasicBlock *BB = PN.getParent(); unsigned MaxAlign = 0; + uint64_t APWidth = DL.getIndexTypeSizeInBits(PN.getType()); + APInt MaxSize(APWidth, 0); bool HaveLoad = false; for (User *U : PN.users()) { LoadInst *LI = dyn_cast<LoadInst>(U); @@ -1197,15 +1218,15 @@ static bool isSafePHIToSpeculate(PHINode &PN) { if (BBI->mayWriteToMemory()) return false; + uint64_t Size = DL.getTypeStoreSizeInBits(LI->getType()); MaxAlign = std::max(MaxAlign, LI->getAlignment()); + MaxSize = MaxSize.ult(Size) ? 
APInt(APWidth, Size) : MaxSize; HaveLoad = true; } if (!HaveLoad) return false; - const DataLayout &DL = PN.getModule()->getDataLayout(); - // We can only transform this if it is safe to push the loads into the // predecessor blocks. The only thing to watch out for is that we can't put // a possibly trapping load in the predecessor if it is a critical edge. @@ -1227,7 +1248,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) { // If this pointer is always safe to load, or if we can prove that there // is already a load in the block, then we can move the load to the pred // block. - if (isSafeToLoadUnconditionally(InVal, MaxAlign, DL, TI)) + if (isSafeToLoadUnconditionally(InVal, MaxAlign, MaxSize, DL, TI)) continue; return false; @@ -1239,15 +1260,14 @@ static bool isSafePHIToSpeculate(PHINode &PN) { static void speculatePHINodeLoads(PHINode &PN) { LLVM_DEBUG(dbgs() << " original: " << PN << "\n"); - Type *LoadTy = cast<PointerType>(PN.getType())->getElementType(); + LoadInst *SomeLoad = cast<LoadInst>(PN.user_back()); + Type *LoadTy = SomeLoad->getType(); IRBuilderTy PHIBuilder(&PN); PHINode *NewPN = PHIBuilder.CreatePHI(LoadTy, PN.getNumIncomingValues(), PN.getName() + ".sroa.speculated"); // Get the AA tags and alignment to use from one of the loads. It doesn't // matter which one we get and if any differ. - LoadInst *SomeLoad = cast<LoadInst>(PN.user_back()); - AAMDNodes AATags; SomeLoad->getAAMetadata(AATags); unsigned Align = SomeLoad->getAlignment(); @@ -1278,7 +1298,8 @@ static void speculatePHINodeLoads(PHINode &PN) { IRBuilderTy PredBuilder(TI); LoadInst *Load = PredBuilder.CreateLoad( - InVal, (PN.getName() + ".sroa.speculate.load." + Pred->getName())); + LoadTy, InVal, + (PN.getName() + ".sroa.speculate.load." + Pred->getName())); ++NumLoadsSpeculated; Load->setAlignment(Align); if (AATags) @@ -1317,9 +1338,11 @@ static bool isSafeSelectToSpeculate(SelectInst &SI) { // Both operands to the select need to be dereferenceable, either // absolutely (e.g. allocas) or at this point because we can see other // accesses to it. - if (!isSafeToLoadUnconditionally(TValue, LI->getAlignment(), DL, LI)) + if (!isSafeToLoadUnconditionally(TValue, LI->getType(), LI->getAlignment(), + DL, LI)) return false; - if (!isSafeToLoadUnconditionally(FValue, LI->getAlignment(), DL, LI)) + if (!isSafeToLoadUnconditionally(FValue, LI->getType(), LI->getAlignment(), + DL, LI)) return false; } @@ -1338,10 +1361,10 @@ static void speculateSelectInstLoads(SelectInst &SI) { assert(LI->isSimple() && "We only speculate simple loads"); IRB.SetInsertPoint(LI); - LoadInst *TL = - IRB.CreateLoad(TV, LI->getName() + ".sroa.speculate.load.true"); - LoadInst *FL = - IRB.CreateLoad(FV, LI->getName() + ".sroa.speculate.load.false"); + LoadInst *TL = IRB.CreateLoad(LI->getType(), TV, + LI->getName() + ".sroa.speculate.load.true"); + LoadInst *FL = IRB.CreateLoad(LI->getType(), FV, + LI->getName() + ".sroa.speculate.load.false"); NumLoadsSpeculated += 2; // Transfer alignment and AA info if present. 
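The recurring mechanical change in the SROA hunks above, and in the Scalarizer hunks later in this patch, is switching IRBuilder calls such as CreateLoad and CreateAlignedLoad to overloads that take the loaded type explicitly instead of deriving it from the pointer operand's pointee type. A minimal sketch of the new-style call, assuming a caller that already has an IRBuilder<> and an existing LoadInst to mirror; the helper name emitSpeculatedLoad is illustrative only and not part of the patch:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Emit a copy of load LI from pointer Ptr using the explicit-type overload:
// the result type comes from the existing load, not from Ptr's pointee type.
static LoadInst *emitSpeculatedLoad(IRBuilder<> &Builder, LoadInst &LI,
                                    Value *Ptr, const Twine &Suffix) {
  LoadInst *NewLoad = Builder.CreateAlignedLoad(
      LI.getType(), Ptr, LI.getAlignment(), LI.getName() + Suffix);
  // Carry over AA metadata, as speculatePHINodeLoads and
  // speculateSelectInstLoads do above.
  AAMDNodes AATags;
  LI.getAAMetadata(AATags);
  if (AATags)
    NewLoad->setAAMetadata(AATags);
  return NewLoad;
}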
@@ -1379,8 +1402,8 @@ static Value *buildGEP(IRBuilderTy &IRB, Value *BasePtr, if (Indices.size() == 1 && cast<ConstantInt>(Indices.back())->isZero()) return BasePtr; - return IRB.CreateInBoundsGEP(nullptr, BasePtr, Indices, - NamePrefix + "sroa_idx"); + return IRB.CreateInBoundsGEP(BasePtr->getType()->getPointerElementType(), + BasePtr, Indices, NamePrefix + "sroa_idx"); } /// Get a natural GEP off of the BasePtr walking through Ty toward @@ -1569,7 +1592,14 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr, Value *Int8Ptr = nullptr; APInt Int8PtrOffset(Offset.getBitWidth(), 0); - Type *TargetTy = PointerTy->getPointerElementType(); + PointerType *TargetPtrTy = cast<PointerType>(PointerTy); + Type *TargetTy = TargetPtrTy->getElementType(); + + // As `addrspacecast` may be involved, `Ptr` (the storage pointer) may have a + // different address space from the expected `PointerTy` (the pointer to be used). + // Adjust the pointer type based on the original storage pointer. + auto AS = cast<PointerType>(Ptr->getType())->getAddressSpace(); + PointerTy = TargetTy->getPointerTo(AS); do { // First fold any existing GEPs into the offset. @@ -1599,7 +1629,7 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr, OffsetBasePtr = Ptr; // If we also found a pointer of the right type, we're done. if (P->getType() == PointerTy) - return P; + break; } // Stash this pointer if we've found an i8*. @@ -1638,8 +1668,11 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr, Ptr = OffsetPtr; // On the off chance we were targeting i8*, guard the bitcast here. - if (Ptr->getType() != PointerTy) - Ptr = IRB.CreateBitCast(Ptr, PointerTy, NamePrefix + "sroa_cast"); + if (cast<PointerType>(Ptr->getType()) != TargetPtrTy) { + Ptr = IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr, + TargetPtrTy, + NamePrefix + "sroa_cast"); + } return Ptr; } @@ -2418,14 +2451,16 @@ private: unsigned EndIndex = getIndex(NewEndOffset); assert(EndIndex > BeginIndex && "Empty vector!"); - Value *V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); + Value *V = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), "load"); return extractVector(IRB, V, BeginIndex, EndIndex, "vec"); } Value *rewriteIntegerLoad(LoadInst &LI) { assert(IntTy && "We cannot insert an integer to the alloca"); assert(!LI.isVolatile()); - Value *V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); + Value *V = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), "load"); V = convertValue(DL, IRB, V, IntTy); assert(NewBeginOffset >= NewAllocaBeginOffset && "Out of bounds offset"); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; @@ -2469,7 +2504,8 @@ private: (canConvertValue(DL, NewAllocaTy, TargetTy) || (IsLoadPastEnd && NewAllocaTy->isIntegerTy() && TargetTy->isIntegerTy()))) { - LoadInst *NewLI = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), + LoadInst *NewLI = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), LI.isVolatile(), LI.getName()); if (AATags) NewLI->setAAMetadata(AATags); @@ -2505,9 +2541,9 @@ private: } } else { Type *LTy = TargetTy->getPointerTo(AS); - LoadInst *NewLI = IRB.CreateAlignedLoad(getNewAllocaSlicePtr(IRB, LTy), - getSliceAlign(TargetTy), - LI.isVolatile(), LI.getName()); + LoadInst *NewLI = IRB.CreateAlignedLoad( + TargetTy, getNewAllocaSlicePtr(IRB, LTy), getSliceAlign(TargetTy), + LI.isVolatile(), LI.getName()); if (AATags)
NewLI->setAAMetadata(AATags); if (LI.isVolatile()) @@ -2524,8 +2560,7 @@ private: "Only integer type loads and stores are split"); assert(SliceSize < DL.getTypeStoreSize(LI.getType()) && "Split load isn't smaller than original load"); - assert(LI.getType()->getIntegerBitWidth() == - DL.getTypeStoreSizeInBits(LI.getType()) && + assert(DL.typeSizeEqualsStoreSize(LI.getType()) && "Non-byte-multiple bit width"); // Move the insertion point just past the load so that we can refer to it. IRB.SetInsertPoint(&*std::next(BasicBlock::iterator(&LI))); @@ -2533,8 +2568,8 @@ private: // basis for the new value. This allows us to replace the uses of LI with // the computed value, and then replace the placeholder with LI, leaving // LI only used for this computation. - Value *Placeholder = - new LoadInst(UndefValue::get(LI.getType()->getPointerTo(AS))); + Value *Placeholder = new LoadInst( + LI.getType(), UndefValue::get(LI.getType()->getPointerTo(AS))); V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset, "insert"); LI.replaceAllUsesWith(V); @@ -2565,7 +2600,8 @@ private: V = convertValue(DL, IRB, V, SliceTy); // Mix in the existing elements. - Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); + Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), "load"); V = insertVector(IRB, Old, V, BeginIndex, "vec"); } StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); @@ -2581,8 +2617,8 @@ private: assert(IntTy && "We cannot extract an integer from the alloca"); assert(!SI.isVolatile()); if (DL.getTypeSizeInBits(V->getType()) != IntTy->getBitWidth()) { - Value *Old = - IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); + Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), "oldload"); Old = convertValue(DL, IRB, Old, IntTy); assert(BeginOffset >= NewAllocaBeginOffset && "Out of bounds offset"); uint64_t Offset = BeginOffset - NewAllocaBeginOffset; @@ -2619,8 +2655,7 @@ private: assert(!SI.isVolatile()); assert(V->getType()->isIntegerTy() && "Only integer type loads and stores are split"); - assert(V->getType()->getIntegerBitWidth() == - DL.getTypeStoreSizeInBits(V->getType()) && + assert(DL.typeSizeEqualsStoreSize(V->getType()) && "Non-byte-multiple bit width"); IntegerType *NarrowTy = Type::getIntNTy(SI.getContext(), SliceSize * 8); V = extractInteger(DL, IRB, V, NarrowTy, NewBeginOffset - BeginOffset, @@ -2731,15 +2766,26 @@ private: Type *AllocaTy = NewAI.getAllocatedType(); Type *ScalarTy = AllocaTy->getScalarType(); + + const bool CanContinue = [&]() { + if (VecTy || IntTy) + return true; + if (BeginOffset > NewAllocaBeginOffset || + EndOffset < NewAllocaEndOffset) + return false; + auto *C = cast<ConstantInt>(II.getLength()); + if (C->getBitWidth() > 64) + return false; + const auto Len = C->getZExtValue(); + auto *Int8Ty = IntegerType::getInt8Ty(NewAI.getContext()); + auto *SrcTy = VectorType::get(Int8Ty, Len); + return canConvertValue(DL, SrcTy, AllocaTy) && + DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy)); + }(); // If this doesn't map cleanly onto the alloca type, and that type isn't // a single value type, just emit a memset. 
- if (!VecTy && !IntTy && - (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset || - SliceSize != DL.getTypeStoreSize(AllocaTy) || - !AllocaTy->isSingleValueType() || - !DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy)) || - DL.getTypeSizeInBits(ScalarTy) % 8 != 0)) { + if (!CanContinue) { Type *SizeTy = II.getLength()->getType(); Constant *Size = ConstantInt::get(SizeTy, NewEndOffset - NewBeginOffset); CallInst *New = IRB.CreateMemSet( @@ -2774,8 +2820,8 @@ private: if (NumElements > 1) Splat = getVectorSplat(Splat, NumElements); - Value *Old = - IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); + Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), "oldload"); V = insertVector(IRB, Old, Splat, BeginIndex, "vec"); } else if (IntTy) { // If this is a memset on an alloca where we can widen stores, insert the @@ -2787,8 +2833,8 @@ private: if (IntTy && (BeginOffset != NewAllocaBeginOffset || EndOffset != NewAllocaEndOffset)) { - Value *Old = - IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); + Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), "oldload"); Old = convertValue(DL, IRB, Old, IntTy); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; V = insertInteger(DL, IRB, Old, V, Offset, "insert"); @@ -2948,18 +2994,18 @@ private: // Reset the other pointer type to match the register type we're going to // use, but using the address space of the original other pointer. + Type *OtherTy; if (VecTy && !IsWholeAlloca) { if (NumElements == 1) - OtherPtrTy = VecTy->getElementType(); + OtherTy = VecTy->getElementType(); else - OtherPtrTy = VectorType::get(VecTy->getElementType(), NumElements); - - OtherPtrTy = OtherPtrTy->getPointerTo(OtherAS); + OtherTy = VectorType::get(VecTy->getElementType(), NumElements); } else if (IntTy && !IsWholeAlloca) { - OtherPtrTy = SubIntTy->getPointerTo(OtherAS); + OtherTy = SubIntTy; } else { - OtherPtrTy = NewAllocaTy->getPointerTo(OtherAS); + OtherTy = NewAllocaTy; } + OtherPtrTy = OtherTy->getPointerTo(OtherAS); Value *SrcPtr = getAdjustedPtr(IRB, DL, OtherPtr, OtherOffset, OtherPtrTy, OtherPtr->getName() + "."); @@ -2973,28 +3019,30 @@ private: Value *Src; if (VecTy && !IsWholeAlloca && !IsDest) { - Src = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); + Src = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), "load"); Src = extractVector(IRB, Src, BeginIndex, EndIndex, "vec"); } else if (IntTy && !IsWholeAlloca && !IsDest) { - Src = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); + Src = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), "load"); Src = convertValue(DL, IRB, Src, IntTy); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; Src = extractInteger(DL, IRB, Src, SubIntTy, Offset, "extract"); } else { - LoadInst *Load = IRB.CreateAlignedLoad(SrcPtr, SrcAlign, II.isVolatile(), - "copyload"); + LoadInst *Load = IRB.CreateAlignedLoad(OtherTy, SrcPtr, SrcAlign, + II.isVolatile(), "copyload"); if (AATags) Load->setAAMetadata(AATags); Src = Load; } if (VecTy && !IsWholeAlloca && IsDest) { - Value *Old = - IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); + Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), "oldload"); Src = insertVector(IRB, Old, Src, BeginIndex, "vec"); } else if (IntTy && !IsWholeAlloca && IsDest) { - Value *Old = - IRB.CreateAlignedLoad(&NewAI,
NewAI.getAlignment(), "oldload"); + Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, + NewAI.getAlignment(), "oldload"); Old = convertValue(DL, IRB, Old, IntTy); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; Src = insertInteger(DL, IRB, Old, Src, Offset, "insert"); @@ -3031,7 +3079,10 @@ private: ConstantInt *Size = ConstantInt::get(cast<IntegerType>(II.getArgOperand(0)->getType()), NewEndOffset - NewBeginOffset); - Value *Ptr = getNewAllocaSlicePtr(IRB, OldPtr->getType()); + // Lifetime intrinsics always expect an i8* so directly get such a pointer + // for the new alloca slice. + Type *PointerTy = IRB.getInt8PtrTy(OldPtr->getType()->getPointerAddressSpace()); + Value *Ptr = getNewAllocaSlicePtr(IRB, PointerTy); Value *New; if (II.getIntrinsicID() == Intrinsic::lifetime_start) New = IRB.CreateLifetimeStart(Ptr, Size); @@ -3072,8 +3123,9 @@ private: continue; } - assert(isa<BitCastInst>(I) || isa<PHINode>(I) || - isa<SelectInst>(I) || isa<GetElementPtrInst>(I)); + assert(isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I) || + isa<PHINode>(I) || isa<SelectInst>(I) || + isa<GetElementPtrInst>(I)); for (User *U : I->users()) if (Visited.insert(cast<Instruction>(U)).second) Uses.push_back(cast<Instruction>(U)); @@ -3297,8 +3349,8 @@ private: assert(Ty->isSingleValueType()); // Load the single value and insert it using the indices. Value *GEP = - IRB.CreateInBoundsGEP(nullptr, Ptr, GEPIndices, Name + ".gep"); - LoadInst *Load = IRB.CreateAlignedLoad(GEP, Align, Name + ".load"); + IRB.CreateInBoundsGEP(BaseTy, Ptr, GEPIndices, Name + ".gep"); + LoadInst *Load = IRB.CreateAlignedLoad(Ty, GEP, Align, Name + ".load"); if (AATags) Load->setAAMetadata(AATags); Agg = IRB.CreateInsertValue(Agg, Load, Indices, Name + ".insert"); @@ -3342,7 +3394,7 @@ private: Value *ExtractValue = IRB.CreateExtractValue(Agg, Indices, Name + ".extract"); Value *InBoundsGEP = - IRB.CreateInBoundsGEP(nullptr, Ptr, GEPIndices, Name + ".gep"); + IRB.CreateInBoundsGEP(BaseTy, Ptr, GEPIndices, Name + ".gep"); StoreInst *Store = IRB.CreateAlignedStore(ExtractValue, InBoundsGEP, Align); if (AATags) @@ -3374,6 +3426,11 @@ private: return false; } + bool visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) { + enqueueUsers(ASC); + return false; + } + bool visitGetElementPtrInst(GetElementPtrInst &GEPI) { enqueueUsers(GEPI); return false; @@ -3792,6 +3849,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { auto AS = LI->getPointerAddressSpace(); auto *PartPtrTy = PartTy->getPointerTo(AS); LoadInst *PLoad = IRB.CreateAlignedLoad( + PartTy, getAdjustedPtr(IRB, DL, BasePtr, APInt(DL.getIndexSizeInBits(AS), PartOffset), PartPtrTy, BasePtr->getName() + "."), @@ -3933,6 +3991,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { IRB.SetInsertPoint(LI); auto AS = LI->getPointerAddressSpace(); PLoad = IRB.CreateAlignedLoad( + PartTy, getAdjustedPtr(IRB, DL, LoadBasePtr, APInt(DL.getIndexSizeInBits(AS), PartOffset), LoadPartPtrTy, LoadBasePtr->getName() + "."), diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp index 976daf4c78c2..869cf00e0a89 100644 --- a/lib/Transforms/Scalar/Scalar.cpp +++ b/lib/Transforms/Scalar/Scalar.cpp @@ -1,9 +1,8 @@ //===-- Scalar.cpp --------------------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. 
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -63,6 +62,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeJumpThreadingPass(Registry); initializeLegacyLICMPassPass(Registry); initializeLegacyLoopSinkPassPass(Registry); + initializeLoopFuseLegacyPass(Registry); initializeLoopDataPrefetchLegacyPassPass(Registry); initializeLoopDeletionLegacyPassPass(Registry); initializeLoopAccessLegacyAnalysisPass(Registry); @@ -81,8 +81,9 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLowerAtomicLegacyPassPass(Registry); initializeLowerExpectIntrinsicPass(Registry); initializeLowerGuardIntrinsicLegacyPassPass(Registry); + initializeLowerWidenableConditionLegacyPassPass(Registry); initializeMemCpyOptLegacyPassPass(Registry); - initializeMergeICmpsPass(Registry); + initializeMergeICmpsLegacyPassPass(Registry); initializeMergedLoadStoreMotionLegacyPassPass(Registry); initializeNaryReassociateLegacyPassPass(Registry); initializePartiallyInlineLibCallsLegacyPassPass(Registry); diff --git a/lib/Transforms/Scalar/Scalarizer.cpp b/lib/Transforms/Scalar/Scalarizer.cpp index 5eb3fdab6d5c..2ee1a3a95f2a 100644 --- a/lib/Transforms/Scalar/Scalarizer.cpp +++ b/lib/Transforms/Scalar/Scalarizer.cpp @@ -1,9 +1,8 @@ //===- Scalarizer.cpp - Scalarize vector operations -----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -125,6 +124,18 @@ struct ICmpSplitter { ICmpInst &ICI; }; +// UnarySpliiter(UO)(Builder, X, Name) uses Builder to create +// a unary operator like UO called Name with operand X. +struct UnarySplitter { + UnarySplitter(UnaryOperator &uo) : UO(uo) {} + + Value *operator()(IRBuilder<> &Builder, Value *Op, const Twine &Name) const { + return Builder.CreateUnOp(UO.getOpcode(), Op, Name); + } + + UnaryOperator &UO; +}; + // BinarySpliiter(BO)(Builder, X, Y, Name) uses Builder to create // a binary operator like BO called Name with operands X and Y. 
struct BinarySplitter { @@ -174,6 +185,7 @@ public: bool visitSelectInst(SelectInst &SI); bool visitICmpInst(ICmpInst &ICI); bool visitFCmpInst(FCmpInst &FCI); + bool visitUnaryOperator(UnaryOperator &UO); bool visitBinaryOperator(BinaryOperator &BO); bool visitGetElementPtrInst(GetElementPtrInst &GEPI); bool visitCastInst(CastInst &CI); @@ -188,11 +200,12 @@ private: Scatterer scatter(Instruction *Point, Value *V); void gather(Instruction *Op, const ValueVector &CV); bool canTransferMetadata(unsigned Kind); - void transferMetadata(Instruction *Op, const ValueVector &CV); + void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV); bool getVectorLayout(Type *Ty, unsigned Alignment, VectorLayout &Layout, const DataLayout &DL); bool finish(); + template<typename T> bool splitUnary(Instruction &, const T &); template<typename T> bool splitBinary(Instruction &, const T &); bool splitCall(CallInst &CI); @@ -246,14 +259,13 @@ Value *Scatterer::operator[](unsigned I) { return CV[I]; IRBuilder<> Builder(BB, BBI); if (PtrTy) { + Type *ElTy = PtrTy->getElementType()->getVectorElementType(); if (!CV[0]) { - Type *Ty = - PointerType::get(PtrTy->getElementType()->getVectorElementType(), - PtrTy->getAddressSpace()); - CV[0] = Builder.CreateBitCast(V, Ty, V->getName() + ".i0"); + Type *NewPtrTy = PointerType::get(ElTy, PtrTy->getAddressSpace()); + CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0"); } if (I != 0) - CV[I] = Builder.CreateConstGEP1_32(nullptr, CV[0], I, + CV[I] = Builder.CreateConstGEP1_32(ElTy, CV[0], I, V->getName() + ".i" + Twine(I)); } else { // Search through a chain of InsertElementInsts looking for element I. @@ -349,7 +361,7 @@ void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) { for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) Op->setOperand(I, UndefValue::get(Op->getOperand(I)->getType())); - transferMetadata(Op, CV); + transferMetadataAndIRFlags(Op, CV); // If we already have a scattered form of Op (created from ExtractElements // of Op itself), replace them with the new form. @@ -385,7 +397,8 @@ bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) { // Transfer metadata from Op to the instructions in CV if it is known // to be safe to do so. -void ScalarizerVisitor::transferMetadata(Instruction *Op, const ValueVector &CV) { +void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op, + const ValueVector &CV) { SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; Op->getAllMetadataOtherThanDebugLoc(MDs); for (unsigned I = 0, E = CV.size(); I != E; ++I) { @@ -393,6 +406,7 @@ void ScalarizerVisitor::transferMetadata(Instruction *Op, const ValueVector &CV) for (const auto &MD : MDs) if (canTransferMetadata(MD.first)) New->setMetadata(MD.first, MD.second); + New->copyIRFlags(Op); if (Op->getDebugLoc() && !New->getDebugLoc()) New->setDebugLoc(Op->getDebugLoc()); } @@ -410,8 +424,7 @@ bool ScalarizerVisitor::getVectorLayout(Type *Ty, unsigned Alignment, // Check that we're dealing with full-byte elements. Layout.ElemTy = Layout.VecTy->getElementType(); - if (DL.getTypeSizeInBits(Layout.ElemTy) != - DL.getTypeStoreSizeInBits(Layout.ElemTy)) + if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy)) return false; if (Alignment) @@ -422,6 +435,26 @@ bool ScalarizerVisitor::getVectorLayout(Type *Ty, unsigned Alignment, return true; } +// Scalarize one-operand instruction I, using Split(Builder, X, Name) +// to create an instruction like I with operand X and name Name. 
+template<typename Splitter> +bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) { + VectorType *VT = dyn_cast<VectorType>(I.getType()); + if (!VT) + return false; + + unsigned NumElems = VT->getNumElements(); + IRBuilder<> Builder(&I); + Scatterer Op = scatter(&I, I.getOperand(0)); + assert(Op.size() == NumElems && "Mismatched unary operation"); + ValueVector Res; + Res.resize(NumElems); + for (unsigned Elem = 0; Elem < NumElems; ++Elem) + Res[Elem] = Split(Builder, Op[Elem], I.getName() + ".i" + Twine(Elem)); + gather(&I, Res); + return true; +} + // Scalarize two-operand instruction I, using Split(Builder, X, Y, Name) // to create an instruction like I with operands X and Y and name Name. template<typename Splitter> @@ -554,6 +587,10 @@ bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) { return splitBinary(FCI, FCmpSplitter(FCI)); } +bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) { + return splitUnary(UO, UnarySplitter(UO)); +} + bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) { return splitBinary(BO, BinarySplitter(BO)); } @@ -744,7 +781,8 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) { Res.resize(NumElems); for (unsigned I = 0; I < NumElems; ++I) - Res[I] = Builder.CreateAlignedLoad(Ptr[I], Layout.getElemAlign(I), + Res[I] = Builder.CreateAlignedLoad(Layout.VecTy->getElementType(), Ptr[I], + Layout.getElemAlign(I), LI.getName() + ".i" + Twine(I)); gather(&LI, Res); return true; @@ -773,7 +811,7 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) { unsigned Align = Layout.getElemAlign(I); Stores[I] = Builder.CreateAlignedStore(Val[I], Ptr[I], Align); } - transferMetadata(&SI, Stores); + transferMetadataAndIRFlags(&SI, Stores); return true; } diff --git a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 967f4a42a8fb..f6a12fb13142 100644 --- a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -1,9 +1,8 @@ //===- SeparateConstOffsetFromGEP.cpp -------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index 5a67178cef37..aeac6f548b32 100644 --- a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -1,9 +1,8 @@ ///===- SimpleLoopUnswitch.cpp - Hoist loop-invariant control flow ---------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -181,14 +180,9 @@ static void buildPartialUnswitchConditionalBranch(BasicBlock &BB, BasicBlock &UnswitchedSucc, BasicBlock &NormalSucc) { IRBuilder<> IRB(&BB); - Value *Cond = Invariants.front(); - for (Value *Invariant : - make_range(std::next(Invariants.begin()), Invariants.end())) - if (Direction) - Cond = IRB.CreateOr(Cond, Invariant); - else - Cond = IRB.CreateAnd(Cond, Invariant); - + + Value *Cond = Direction ? IRB.CreateOr(Invariants) : + IRB.CreateAnd(Invariants); IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, Direction ? &NormalSucc : &UnswitchedSucc); } @@ -268,7 +262,8 @@ static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB, /// loops reachable and need to move the current loop up the loop nest or even /// to an entirely separate nest. static void hoistLoopToNewParent(Loop &L, BasicBlock &Preheader, - DominatorTree &DT, LoopInfo &LI) { + DominatorTree &DT, LoopInfo &LI, + MemorySSAUpdater *MSSAU) { // If the loop is already at the top level, we can't hoist it anywhere. Loop *OldParentL = L.getParentLoop(); if (!OldParentL) @@ -329,7 +324,8 @@ static void hoistLoopToNewParent(Loop &L, BasicBlock &Preheader, // unswitching it is possible to get new non-dedicated exits out of parent // loop so let's conservatively form dedicated exit blocks and figure out // if we can optimize later. - formDedicatedExitBlocks(OldContainingL, &DT, &LI, /*PreserveLCSSA*/ true); + formDedicatedExitBlocks(OldContainingL, &DT, &LI, MSSAU, + /*PreserveLCSSA*/ true); } } @@ -536,7 +532,10 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // If this was full unswitching, we may have changed the nesting relationship // for this loop so hoist it to its correct parent if needed. if (FullUnswitch) - hoistLoopToNewParent(L, *NewPH, DT, LI); + hoistLoopToNewParent(L, *NewPH, DT, LI, MSSAU); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); LLVM_DEBUG(dbgs() << " done: unswitching trivial branch...\n"); ++NumTrivial; @@ -590,11 +589,13 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, ExitCaseIndices.push_back(Case.getCaseIndex()); } BasicBlock *DefaultExitBB = nullptr; + SwitchInstProfUpdateWrapper::CaseWeightOpt DefaultCaseWeight = + SwitchInstProfUpdateWrapper::getSuccessorWeight(SI, 0); if (!L.contains(SI.getDefaultDest()) && areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) && - !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) + !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) { DefaultExitBB = SI.getDefaultDest(); - else if (ExitCaseIndices.empty()) + } else if (ExitCaseIndices.empty()) return false; LLVM_DEBUG(dbgs() << " unswitching trivial switch...\n"); @@ -618,8 +619,11 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, // Store the exit cases into a separate data structure and remove them from // the switch. - SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases; + SmallVector<std::tuple<ConstantInt *, BasicBlock *, + SwitchInstProfUpdateWrapper::CaseWeightOpt>, + 4> ExitCases; ExitCases.reserve(ExitCaseIndices.size()); + SwitchInstProfUpdateWrapper SIW(SI); // We walk the case indices backwards so that we remove the last case first // and don't disrupt the earlier indices. 
for (unsigned Index : reverse(ExitCaseIndices)) { @@ -629,9 +633,10 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, if (!ExitL || ExitL->contains(OuterL)) OuterL = ExitL; // Save the value of this case. - ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()}); + auto W = SIW.getSuccessorWeight(CaseI->getSuccessorIndex()); + ExitCases.emplace_back(CaseI->getCaseValue(), CaseI->getCaseSuccessor(), W); // Delete the unswitched cases. - SI.removeCase(CaseI); + SIW.removeCase(CaseI); } if (SE) { @@ -669,6 +674,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, // Now add the unswitched switch. auto *NewSI = SwitchInst::Create(LoopCond, NewPH, ExitCases.size(), OldPH); + SwitchInstProfUpdateWrapper NewSIW(*NewSI); // Rewrite the IR for the unswitched basic blocks. This requires two steps. // First, we split any exit blocks with remaining in-loop predecessors. Then @@ -696,9 +702,9 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, } // Note that we must use a reference in the for loop so that we update the // container. - for (auto &CasePair : reverse(ExitCases)) { + for (auto &ExitCase : reverse(ExitCases)) { // Grab a reference to the exit block in the pair so that we can update it. - BasicBlock *ExitBB = CasePair.second; + BasicBlock *ExitBB = std::get<1>(ExitCase); // If this case is the last edge into the exit block, we can simply reuse it // as it will no longer be a loop exit. No mapping necessary. @@ -720,27 +726,39 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, /*FullUnswitch*/ true); } // Update the case pair to point to the split block. - CasePair.second = SplitExitBB; + std::get<1>(ExitCase) = SplitExitBB; } // Now add the unswitched cases. We do this in reverse order as we built them // in reverse order. - for (auto CasePair : reverse(ExitCases)) { - ConstantInt *CaseVal = CasePair.first; - BasicBlock *UnswitchedBB = CasePair.second; + for (auto &ExitCase : reverse(ExitCases)) { + ConstantInt *CaseVal = std::get<0>(ExitCase); + BasicBlock *UnswitchedBB = std::get<1>(ExitCase); - NewSI->addCase(CaseVal, UnswitchedBB); + NewSIW.addCase(CaseVal, UnswitchedBB, std::get<2>(ExitCase)); } // If the default was unswitched, re-point it and add explicit cases for // entering the loop. if (DefaultExitBB) { - NewSI->setDefaultDest(DefaultExitBB); + NewSIW->setDefaultDest(DefaultExitBB); + NewSIW.setSuccessorWeight(0, DefaultCaseWeight); // We removed all the exit cases, so we just copy the cases to the // unswitched switch. - for (auto Case : SI.cases()) - NewSI->addCase(Case.getCaseValue(), NewPH); + for (const auto &Case : SI.cases()) + NewSIW.addCase(Case.getCaseValue(), NewPH, + SIW.getSuccessorWeight(Case.getSuccessorIndex())); + } else if (DefaultCaseWeight) { + // We have to set branch weight of the default case. + uint64_t SW = *DefaultCaseWeight; + for (const auto &Case : SI.cases()) { + auto W = SIW.getSuccessorWeight(Case.getSuccessorIndex()); + assert(W && + "case weight must be defined as default case weight is defined"); + SW += *W; + } + NewSIW.setSuccessorWeight(0, SW); } // If we ended up with a common successor for every path through the switch @@ -762,10 +780,10 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, continue; } CommonSuccBB->removePredecessor(BB, - /*DontDeleteUselessPHIs*/ true); + /*KeepOneInputPHIs*/ true); } // Now nuke the switch and replace it with a direct branch. 
- SI.eraseFromParent(); + SIW.eraseFromParent(); BranchInst::Create(CommonSuccBB, BB); } else if (DefaultExitBB) { assert(SI.getNumCases() > 0 && @@ -775,8 +793,11 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, // being simple and keeping the number of edges from this switch to // successors the same, and avoiding any PHI update complexity. auto LastCaseI = std::prev(SI.case_end()); + SI.setDefaultDest(LastCaseI->getCaseSuccessor()); - SI.removeCase(LastCaseI); + SIW.setSuccessorWeight( + 0, SIW.getSuccessorWeight(LastCaseI->getSuccessorIndex())); + SIW.removeCase(LastCaseI); } // Walk the unswitched exit blocks and the unswitched split blocks and update @@ -789,9 +810,8 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, DTUpdates.push_back({DT.Insert, OldPH, UnswitchedExitBB}); } for (auto SplitUnswitchedPair : SplitExitBBMap) { - auto *UnswitchedBB = SplitUnswitchedPair.second; - DTUpdates.push_back({DT.Delete, ParentBB, UnswitchedBB}); - DTUpdates.push_back({DT.Insert, OldPH, UnswitchedBB}); + DTUpdates.push_back({DT.Delete, ParentBB, SplitUnswitchedPair.first}); + DTUpdates.push_back({DT.Insert, OldPH, SplitUnswitchedPair.second}); } DT.applyUpdates(DTUpdates); @@ -805,7 +825,10 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, // We may have changed the nesting relationship for this loop so hoist it to // its correct parent if needed. - hoistLoopToNewParent(L, *NewPH, DT, LI); + hoistLoopToNewParent(L, *NewPH, DT, LI, MSSAU); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); ++NumTrivial; ++NumSwitches; @@ -848,6 +871,10 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, // Check if there are any side-effecting instructions (e.g. stores, calls, // volatile loads) in the part of the loop that the code *would* execute // without unswitching. + if (MSSAU) // Possible early exit with MSSA + if (auto *Defs = MSSAU->getMemorySSA()->getBlockDefs(CurrentBB)) + if (!isa<MemoryPhi>(*Defs->begin()) || (++Defs->begin() != Defs->end())) + return Changed; if (llvm::any_of(*CurrentBB, [](Instruction &I) { return I.mayHaveSideEffects(); })) return Changed; @@ -1066,7 +1093,7 @@ static BasicBlock *buildClonedLoopBlocks( continue; ClonedSuccBB->removePredecessor(ClonedParentBB, - /*DontDeleteUselessPHIs*/ true); + /*KeepOneInputPHIs*/ true); } // Replace the cloned branch with an unconditional branch to the cloned @@ -1436,8 +1463,8 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, // Remove all MemorySSA in the dead blocks if (MSSAU) { - SmallPtrSet<BasicBlock *, 16> DeadBlockSet(DeadBlocks.begin(), - DeadBlocks.end()); + SmallSetVector<BasicBlock *, 8> DeadBlockSet(DeadBlocks.begin(), + DeadBlocks.end()); MSSAU->removeBlocks(DeadBlockSet); } @@ -1455,7 +1482,7 @@ static void deleteDeadBlocksFromLoop(Loop &L, MemorySSAUpdater *MSSAU) { // Find all the dead blocks tied to this loop, and remove them from their // successors. - SmallPtrSet<BasicBlock *, 16> DeadBlockSet; + SmallSetVector<BasicBlock *, 8> DeadBlockSet; // Start with loop/exit blocks and get a transitive closure of reachable dead // blocks. @@ -1712,10 +1739,9 @@ static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, // Sort the exits in ascending loop depth, we'll work backwards across these // to process them inside out. 
- std::stable_sort(ExitsInLoops.begin(), ExitsInLoops.end(), - [&](BasicBlock *LHS, BasicBlock *RHS) { - return LI.getLoopDepth(LHS) < LI.getLoopDepth(RHS); - }); + llvm::stable_sort(ExitsInLoops, [&](BasicBlock *LHS, BasicBlock *RHS) { + return LI.getLoopDepth(LHS) < LI.getLoopDepth(RHS); + }); // We'll build up a set for each exit loop. SmallPtrSet<BasicBlock *, 16> NewExitLoopBlocks; @@ -2075,7 +2101,7 @@ static void unswitchNontrivialInvariants( "Only one possible unswitched block for a branch!"); BasicBlock *UnswitchedSuccBB = *UnswitchedSuccBBs.begin(); UnswitchedSuccBB->removePredecessor(ParentBB, - /*DontDeleteUselessPHIs*/ true); + /*KeepOneInputPHIs*/ true); DTUpdates.push_back({DominatorTree::Delete, ParentBB, UnswitchedSuccBB}); } else { // Note that we actually want to remove the parent block as a predecessor @@ -2090,7 +2116,7 @@ static void unswitchNontrivialInvariants( for (auto &Case : NewSI->cases()) Case.getCaseSuccessor()->removePredecessor( ParentBB, - /*DontDeleteUselessPHIs*/ true); + /*KeepOneInputPHIs*/ true); // We need to use the set to populate domtree updates as even when there // are multiple cases pointing at the same successor we only want to @@ -2236,7 +2262,7 @@ static void unswitchNontrivialInvariants( // introduced new, non-dedicated exits. At least try to re-form dedicated // exits for these loops. This may fail if they couldn't have dedicated // exits to start with. - formDedicatedExitBlocks(&UpdateL, &DT, &LI, /*PreserveLCSSA*/ true); + formDedicatedExitBlocks(&UpdateL, &DT, &LI, MSSAU, /*PreserveLCSSA*/ true); }; // For non-child cloned loops and hoisted loops, we just need to update LCSSA @@ -2526,7 +2552,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, // We can only consider fully loop-invariant switch conditions as we need // to completely eliminate the switch after unswitching. if (!isa<Constant>(SI->getCondition()) && - L.isLoopInvariant(SI->getCondition())) + L.isLoopInvariant(SI->getCondition()) && !BB->getUniqueSuccessor()) UnswitchCandidates.push_back({SI, {SI->getCondition()}}); continue; } @@ -2852,7 +2878,11 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, // Historically this pass has had issues with the dominator tree so verify it // in asserts builds. assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast)); - return getLoopPassPreservedAnalyses(); + + auto PA = getLoopPassPreservedAnalyses(); + if (EnableMSSALoopDependency) + PA.preserve<MemorySSAAnalysis>(); + return PA; } namespace { diff --git a/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/lib/Transforms/Scalar/SimplifyCFGPass.cpp index b7b1db76b492..4544975a4887 100644 --- a/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -1,9 +1,8 @@ //===- SimplifyCFGPass.cpp - CFG Simplification Pass ----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/Sink.cpp b/lib/Transforms/Scalar/Sink.cpp index c99da8f0737a..90f3a2aa46e1 100644 --- a/lib/Transforms/Scalar/Sink.cpp +++ b/lib/Transforms/Scalar/Sink.cpp @@ -1,9 +1,8 @@ //===-- Sink.cpp - Code Sinking -------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp index c0f75ddddbe0..c13fb3e04516 100644 --- a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp +++ b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp @@ -1,9 +1,8 @@ //===- SpeculateAroundPHIs.cpp --------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -68,6 +67,14 @@ isSafeToSpeculatePHIUsers(PHINode &PN, DominatorTree &DT, return false; } + if (auto CS = ImmutableCallSite(UI)) { + if (CS.isConvergent() || CS.cannotDuplicate()) { + LLVM_DEBUG(dbgs() << " Unsafe: convergent " + "callsite cannot be duplicated: " << *UI << '\n'); + return false; + } + } + // FIXME: This check is much too conservative. We're not going to move these // instructions onto new dynamic paths through the program unless there is // a call instruction between the use and the PHI node. And memory isn't diff --git a/lib/Transforms/Scalar/SpeculativeExecution.cpp b/lib/Transforms/Scalar/SpeculativeExecution.cpp index f5e1dd6ed850..f9d027eb4a3b 100644 --- a/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -1,9 +1,8 @@ //===- SpeculativeExecution.cpp ---------------------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -241,6 +240,7 @@ static unsigned ComputeSpeculationCost(const Instruction *I, case Instruction::FMul: case Instruction::FDiv: case Instruction::FRem: + case Instruction::FNeg: case Instruction::ICmp: case Instruction::FCmp: return TTI.getUserCost(I); diff --git a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index b5089b006bdd..a58c32cc5894 100644 --- a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -1,9 +1,8 @@ //===- StraightLineStrengthReduce.cpp - -----------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -683,9 +682,13 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis( // Canonicalize bump to pointer size. Bump = Builder.CreateSExtOrTrunc(Bump, IntPtrTy); if (InBounds) - Reduced = Builder.CreateInBoundsGEP(nullptr, Basis.Ins, Bump); + Reduced = Builder.CreateInBoundsGEP( + cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(), + Basis.Ins, Bump); else - Reduced = Builder.CreateGEP(nullptr, Basis.Ins, Bump); + Reduced = Builder.CreateGEP( + cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(), + Basis.Ins, Bump); } break; } diff --git a/lib/Transforms/Scalar/StructurizeCFG.cpp b/lib/Transforms/Scalar/StructurizeCFG.cpp index 0db762d846f2..e5400676c7e8 100644 --- a/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -1,9 +1,8 @@ //===- StructurizeCFG.cpp -------------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -63,6 +62,11 @@ static cl::opt<bool> ForceSkipUniformRegions( cl::desc("Force whether the StructurizeCFG pass skips uniform regions"), cl::init(false)); +static cl::opt<bool> + RelaxedUniformRegions("structurizecfg-relaxed-uniform-regions", cl::Hidden, + cl::desc("Allow relaxed uniform region checks"), + cl::init(false)); + // Definition of the complex types used in this pass. 
using BBValuePair = std::pair<BasicBlock *, Value *>; @@ -624,11 +628,8 @@ void StructurizeCFG::setPhiValues() { if (!Dominator.resultIsRememberedBlock()) Updater.AddAvailableValue(Dominator.result(), Undef); - for (BasicBlock *FI : From) { - int Idx = Phi->getBasicBlockIndex(FI); - assert(Idx != -1); - Phi->setIncomingValue(Idx, Updater.GetValueAtEndOfBlock(FI)); - } + for (BasicBlock *FI : From) + Phi->setIncomingValueForBlock(FI, Updater.GetValueAtEndOfBlock(FI)); } DeletedPhis.erase(To); @@ -937,6 +938,11 @@ void StructurizeCFG::rebuildSSA() { static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID, const LegacyDivergenceAnalysis &DA) { + // Bool for if all sub-regions are uniform. + bool SubRegionsAreUniform = true; + // Count of how many direct children are conditional. + unsigned ConditionalDirectChildren = 0; + for (auto E : R->elements()) { if (!E->isSubRegion()) { auto Br = dyn_cast<BranchInst>(E->getEntry()->getTerminator()); @@ -945,6 +951,10 @@ static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID, if (!DA.isUniform(Br)) return false; + + // One of our direct children is conditional. + ConditionalDirectChildren++; + LLVM_DEBUG(dbgs() << "BB: " << Br->getParent()->getName() << " has uniform terminator\n"); } else { @@ -962,12 +972,25 @@ static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID, if (!Br || !Br->isConditional()) continue; - if (!Br->getMetadata(UniformMDKindID)) - return false; + if (!Br->getMetadata(UniformMDKindID)) { + // Early exit if we cannot have relaxed uniform regions. + if (!RelaxedUniformRegions) + return false; + + SubRegionsAreUniform = false; + break; + } } } } - return true; + + // Our region is uniform if: + // 1. All conditional branches that are direct children are uniform (checked + // above). + // 2. And either: + // a. All sub-regions are uniform. + // b. There is one or less conditional branches among the direct children. + return SubRegionsAreUniform || (ConditionalDirectChildren <= 1); } /// Run the transformation for each region found diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp index 0f6db21f73b6..f0b79079d817 100644 --- a/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -1,9 +1,8 @@ //===- TailRecursionElimination.cpp - Eliminate Tail Calls ----------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -56,6 +55,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -69,7 +69,6 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" @@ -341,7 +340,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) { // being loaded from. 
const DataLayout &DL = L->getModule()->getDataLayout(); if (isModSet(AA->getModRefInfo(CI, MemoryLocation::get(L))) || - !isSafeToLoadUnconditionally(L->getPointerOperand(), + !isSafeToLoadUnconditionally(L->getPointerOperand(), L->getType(), L->getAlignment(), DL, L)) return false; } @@ -679,7 +678,7 @@ static bool eliminateRecursiveTailCall( BB->getInstList().erase(Ret); // Remove return. BB->getInstList().erase(CI); // Remove call. - DTU.insertEdge(BB, OldEntry); + DTU.applyUpdates({{DominatorTree::Insert, BB, OldEntry}}); ++NumEliminated; return true; } diff --git a/lib/Transforms/Scalar/WarnMissedTransforms.cpp b/lib/Transforms/Scalar/WarnMissedTransforms.cpp index 80f761e53774..707adf46d1f4 100644 --- a/lib/Transforms/Scalar/WarnMissedTransforms.cpp +++ b/lib/Transforms/Scalar/WarnMissedTransforms.cpp @@ -1,9 +1,8 @@ //===- LoopTransformWarning.cpp - ----------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -93,7 +92,7 @@ PreservedAnalyses WarnMissedTransformationsPass::run(Function &F, FunctionAnalysisManager &AM) { // Do not warn about not applied transformations if optimizations are // disabled. - if (F.hasFnAttribute(Attribute::OptimizeNone)) + if (F.hasOptNone()) return PreservedAnalyses::all(); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); diff --git a/lib/Transforms/Utils/ASanStackFrameLayout.cpp b/lib/Transforms/Utils/ASanStackFrameLayout.cpp index 364878dc588d..01912297324a 100644 --- a/lib/Transforms/Utils/ASanStackFrameLayout.cpp +++ b/lib/Transforms/Utils/ASanStackFrameLayout.cpp @@ -1,9 +1,8 @@ //===-- ASanStackFrameLayout.cpp - helper for AddressSanitizer ------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -63,7 +62,7 @@ ComputeASanStackFrameLayout(SmallVectorImpl<ASanStackVariableDescription> &Vars, for (size_t i = 0; i < NumVars; i++) Vars[i].Alignment = std::max(Vars[i].Alignment, kMinAlignment); - std::stable_sort(Vars.begin(), Vars.end(), CompareVars); + llvm::stable_sort(Vars, CompareVars); ASanStackFrameLayout Layout; Layout.Granularity = Granularity; diff --git a/lib/Transforms/Utils/AddDiscriminators.cpp b/lib/Transforms/Utils/AddDiscriminators.cpp index 564537af0c2a..ee0973002c47 100644 --- a/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/lib/Transforms/Utils/AddDiscriminators.cpp @@ -1,9 +1,8 @@ //===- AddDiscriminators.cpp - Insert DWARF path discriminators -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -209,7 +208,7 @@ static bool addDiscriminators(Function &F) { // Only the lowest 7 bits are used to represent a discriminator to fit // it in 1 byte ULEB128 representation. unsigned Discriminator = R.second ? ++LDM[L] : LDM[L]; - auto NewDIL = DIL->setBaseDiscriminator(Discriminator); + auto NewDIL = DIL->cloneWithBaseDiscriminator(Discriminator); if (!NewDIL) { LLVM_DEBUG(dbgs() << "Could not encode discriminator: " << DIL->getFilename() << ":" << DIL->getLine() << ":" @@ -246,7 +245,7 @@ static bool addDiscriminators(Function &F) { std::make_pair(CurrentDIL->getFilename(), CurrentDIL->getLine()); if (!CallLocations.insert(L).second) { unsigned Discriminator = ++LDM[L]; - auto NewDIL = CurrentDIL->setBaseDiscriminator(Discriminator); + auto NewDIL = CurrentDIL->cloneWithBaseDiscriminator(Discriminator); if (!NewDIL) { LLVM_DEBUG(dbgs() << "Could not encode discriminator: " diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp index 7da768252fc1..5fa371377c85 100644 --- a/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -1,9 +1,8 @@ //===- BasicBlockUtils.cpp - BasicBlock Utilities --------------------------==// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -18,6 +17,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemorySSAUpdater.h" @@ -26,7 +26,6 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -39,6 +38,8 @@ #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> @@ -48,30 +49,20 @@ using namespace llvm; -void llvm::DeleteDeadBlock(BasicBlock *BB, DomTreeUpdater *DTU) { - SmallVector<BasicBlock *, 1> BBs = {BB}; - DeleteDeadBlocks(BBs, DTU); -} - -void llvm::DeleteDeadBlocks(SmallVectorImpl <BasicBlock *> &BBs, - DomTreeUpdater *DTU) { -#ifndef NDEBUG - // Make sure that all predecessors of each dead block is also dead. 
- SmallPtrSet<BasicBlock *, 4> Dead(BBs.begin(), BBs.end()); - assert(Dead.size() == BBs.size() && "Duplicating blocks?"); - for (auto *BB : Dead) - for (BasicBlock *Pred : predecessors(BB)) - assert(Dead.count(Pred) && "All predecessors must be dead!"); -#endif +#define DEBUG_TYPE "basicblock-utils" - SmallVector<DominatorTree::UpdateType, 4> Updates; +void llvm::DetatchDeadBlocks( + ArrayRef<BasicBlock *> BBs, + SmallVectorImpl<DominatorTree::UpdateType> *Updates, + bool KeepOneInputPHIs) { for (auto *BB : BBs) { // Loop through all of our successors and make sure they know that one // of their predecessors is going away. + SmallPtrSet<BasicBlock *, 4> UniqueSuccessors; for (BasicBlock *Succ : successors(BB)) { - Succ->removePredecessor(BB); - if (DTU) - Updates.push_back({DominatorTree::Delete, BB, Succ}); + Succ->removePredecessor(BB, KeepOneInputPHIs); + if (Updates && UniqueSuccessors.insert(Succ).second) + Updates->push_back({DominatorTree::Delete, BB, Succ}); } // Zap all the instructions in the block. @@ -92,8 +83,29 @@ void llvm::DeleteDeadBlocks(SmallVectorImpl <BasicBlock *> &BBs, "The successor list of BB isn't empty before " "applying corresponding DTU updates."); } +} + +void llvm::DeleteDeadBlock(BasicBlock *BB, DomTreeUpdater *DTU, + bool KeepOneInputPHIs) { + DeleteDeadBlocks({BB}, DTU, KeepOneInputPHIs); +} + +void llvm::DeleteDeadBlocks(ArrayRef <BasicBlock *> BBs, DomTreeUpdater *DTU, + bool KeepOneInputPHIs) { +#ifndef NDEBUG + // Make sure that all predecessors of each dead block is also dead. + SmallPtrSet<BasicBlock *, 4> Dead(BBs.begin(), BBs.end()); + assert(Dead.size() == BBs.size() && "Duplicating blocks?"); + for (auto *BB : Dead) + for (BasicBlock *Pred : predecessors(BB)) + assert(Dead.count(Pred) && "All predecessors must be dead!"); +#endif + + SmallVector<DominatorTree::UpdateType, 4> Updates; + DetatchDeadBlocks(BBs, DTU ? &Updates : nullptr, KeepOneInputPHIs); + if (DTU) - DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->applyUpdatesPermissive(Updates); for (BasicBlock *BB : BBs) if (DTU) @@ -102,6 +114,28 @@ void llvm::DeleteDeadBlocks(SmallVectorImpl <BasicBlock *> &BBs, BB->eraseFromParent(); } +bool llvm::EliminateUnreachableBlocks(Function &F, DomTreeUpdater *DTU, + bool KeepOneInputPHIs) { + df_iterator_default_set<BasicBlock*> Reachable; + + // Mark all reachable blocks. + for (BasicBlock *BB : depth_first_ext(&F, Reachable)) + (void)BB/* Mark all reachable blocks */; + + // Collect all dead blocks. + std::vector<BasicBlock*> DeadBlocks; + for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) + if (!Reachable.count(&*I)) { + BasicBlock *BB = &*I; + DeadBlocks.push_back(BB); + } + + // Delete the dead blocks. + DeleteDeadBlocks(DeadBlocks, DTU, KeepOneInputPHIs); + + return !DeadBlocks.empty(); +} + void llvm::FoldSingleEntryPHINodes(BasicBlock *BB, MemoryDependenceResults *MemDep) { if (!isa<PHINode>(BB->begin())) return; @@ -160,6 +194,9 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, if (IncValue == &PN) return false; + LLVM_DEBUG(dbgs() << "Merging: " << BB->getName() << " into " + << PredBB->getName() << "\n"); + // Begin by getting rid of unneeded PHIs. 
SmallVector<AssertingVH<Value>, 4> IncomingValues; if (isa<PHINode>(BB->front())) { @@ -175,11 +212,19 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, std::vector<DominatorTree::UpdateType> Updates; if (DTU) { Updates.reserve(1 + (2 * succ_size(BB))); - Updates.push_back({DominatorTree::Delete, PredBB, BB}); - for (auto I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { + // Add insert edges first. Experimentally, for the particular case of two + // blocks that can be merged, with a single successor and single predecessor + // respectively, it is beneficial to have all insert updates first. Deleting + // edges first may lead to unreachable blocks, followed by inserting edges + // making the blocks reachable again. Such DT updates lead to high compile + // times. We add inserts before deletes here to reduce compile time. + for (auto I = succ_begin(BB), E = succ_end(BB); I != E; ++I) + // This successor of BB may already have PredBB as a predecessor. + if (llvm::find(successors(PredBB), *I) == succ_end(PredBB)) + Updates.push_back({DominatorTree::Insert, PredBB, *I}); + for (auto I = succ_begin(BB), E = succ_end(BB); I != E; ++I) Updates.push_back({DominatorTree::Delete, BB, *I}); - Updates.push_back({DominatorTree::Insert, PredBB, *I}); - } + Updates.push_back({DominatorTree::Delete, PredBB, BB}); } if (MSSAU) @@ -227,7 +272,7 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, isa<UnreachableInst>(BB->getTerminator()) && "The successor list of BB isn't empty before " "applying corresponding DTU updates."); - DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->applyUpdatesPermissive(Updates); DTU->deleteBB(BB); } @@ -534,7 +579,13 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB, // The new block unconditionally branches to the old block. BranchInst *BI = BranchInst::Create(BB, NewBB); - BI->setDebugLoc(BB->getFirstNonPHIOrDbg()->getDebugLoc()); + // Splitting the predecessors of a loop header creates a preheader block. + if (LI && LI->isLoopHeader(BB)) + // Using the loop start line number prevents debuggers stepping into the + // loop body for this instruction. + BI->setDebugLoc(LI->getLoopFor(BB)->getStartLoc()); + else + BI->setDebugLoc(BB->getFirstNonPHIOrDbg()->getDebugLoc()); // Move the edges from Preds to point to NewBB instead of BB. for (unsigned i = 0, e = Preds.size(); i != e; ++i) { @@ -543,6 +594,8 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB, // all BlockAddress uses would need to be updated. 
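The comment above spells out why insert edges are queued before delete edges, and the same hunk switches to applyUpdatesPermissive, which tolerates duplicate and no-op updates. A condensed sketch of that pattern, with BB/PredBB standing in for the merged block and its single predecessor and DTU for the pass's DomTreeUpdater:

    SmallVector<DominatorTree::UpdateType, 8> Updates;
    // Queue insertions first so no block is ever reported as transiently
    // unreachable, then queue the deletions.
    for (BasicBlock *Succ : successors(BB))
      if (!llvm::is_contained(successors(PredBB), Succ))
        Updates.push_back({DominatorTree::Insert, PredBB, Succ});
    for (BasicBlock *Succ : successors(BB))
      Updates.push_back({DominatorTree::Delete, BB, Succ});
    Updates.push_back({DominatorTree::Delete, PredBB, BB});
    DTU->applyUpdatesPermissive(Updates); // duplicates and no-ops are filtered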
assert(!isa<IndirectBrInst>(Preds[i]->getTerminator()) && "Cannot split an edge from an IndirectBrInst"); + assert(!isa<CallBrInst>(Preds[i]->getTerminator()) && + "Cannot split an edge from a CallBrInst"); Preds[i]->getTerminator()->replaceUsesOfWith(BB, NewBB); } @@ -711,7 +764,7 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, UncondBranch->eraseFromParent(); if (DTU) - DTU->deleteEdge(Pred, BB); + DTU->applyUpdates({{DominatorTree::Delete, Pred, BB}}); return cast<ReturnInst>(NewRet); } @@ -720,18 +773,23 @@ Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond, Instruction *SplitBefore, bool Unreachable, MDNode *BranchWeights, - DominatorTree *DT, LoopInfo *LI) { + DominatorTree *DT, LoopInfo *LI, + BasicBlock *ThenBlock) { BasicBlock *Head = SplitBefore->getParent(); BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator()); Instruction *HeadOldTerm = Head->getTerminator(); LLVMContext &C = Head->getContext(); - BasicBlock *ThenBlock = BasicBlock::Create(C, "", Head->getParent(), Tail); Instruction *CheckTerm; - if (Unreachable) - CheckTerm = new UnreachableInst(C, ThenBlock); - else - CheckTerm = BranchInst::Create(Tail, ThenBlock); - CheckTerm->setDebugLoc(SplitBefore->getDebugLoc()); + bool CreateThenBlock = (ThenBlock == nullptr); + if (CreateThenBlock) { + ThenBlock = BasicBlock::Create(C, "", Head->getParent(), Tail); + if (Unreachable) + CheckTerm = new UnreachableInst(C, ThenBlock); + else + CheckTerm = BranchInst::Create(Tail, ThenBlock); + CheckTerm->setDebugLoc(SplitBefore->getDebugLoc()); + } else + CheckTerm = ThenBlock->getTerminator(); BranchInst *HeadNewTerm = BranchInst::Create(/*ifTrue*/ThenBlock, /*ifFalse*/Tail, Cond); HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights); @@ -746,7 +804,10 @@ Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond, DT->changeImmediateDominator(Child, NewNode); // Head dominates ThenBlock. - DT->addNewBlock(ThenBlock, Head); + if (CreateThenBlock) + DT->addNewBlock(ThenBlock, Head); + else + DT->changeImmediateDominator(ThenBlock, Head); } } diff --git a/lib/Transforms/Utils/BreakCriticalEdges.cpp b/lib/Transforms/Utils/BreakCriticalEdges.cpp index fafc9aaba5c9..f5e4b53f6d97 100644 --- a/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -1,9 +1,8 @@ //===- BreakCriticalEdges.cpp - Critical Edge Elimination Pass ------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -24,6 +23,7 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" @@ -49,10 +49,14 @@ namespace { bool runOnFunction(Function &F) override { auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; + + auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); + auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr; + auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); auto *LI = LIWP ? 
&LIWP->getLoopInfo() : nullptr; unsigned N = - SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions(DT, LI)); + SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions(DT, LI, nullptr, PDT)); NumBroken += N; return N > 0; } @@ -145,6 +149,14 @@ llvm::SplitCriticalEdge(Instruction *TI, unsigned SuccNum, // it in this generic function. if (DestBB->isEHPad()) return nullptr; + // Don't split the non-fallthrough edge from a callbr. + if (isa<CallBrInst>(TI) && SuccNum > 0) + return nullptr; + + if (Options.IgnoreUnreachableDests && + isa<UnreachableInst>(DestBB->getFirstNonPHIOrDbgOrLifetime())) + return nullptr; + // Create a new basic block, linking it into the CFG. BasicBlock *NewBB = BasicBlock::Create(TI->getContext(), TIBB->getName() + "." + DestBB->getName() + "_crit_edge"); @@ -189,7 +201,7 @@ llvm::SplitCriticalEdge(Instruction *TI, unsigned SuccNum, if (TI->getSuccessor(i) != DestBB) continue; // Remove an entry for TIBB from DestBB phi nodes. - DestBB->removePredecessor(TIBB, Options.DontDeleteUselessPHIs); + DestBB->removePredecessor(TIBB, Options.KeepOneInputPHIs); // We found another edge to DestBB, go to NewBB instead. TI->setSuccessor(i, NewBB); @@ -198,16 +210,17 @@ llvm::SplitCriticalEdge(Instruction *TI, unsigned SuccNum, // If we have nothing to update, just return. auto *DT = Options.DT; + auto *PDT = Options.PDT; auto *LI = Options.LI; auto *MSSAU = Options.MSSAU; if (MSSAU) MSSAU->wireOldPredecessorsToNewImmediatePredecessor( DestBB, NewBB, {TIBB}, Options.MergeIdenticalEdges); - if (!DT && !LI) + if (!DT && !PDT && !LI) return NewBB; - if (DT) { + if (DT || PDT) { // Update the DominatorTree. // ---> NewBB -----\ // / V @@ -223,7 +236,10 @@ llvm::SplitCriticalEdge(Instruction *TI, unsigned SuccNum, if (llvm::find(successors(TIBB), DestBB) == succ_end(TIBB)) Updates.push_back({DominatorTree::Delete, TIBB, DestBB}); - DT->applyUpdates(Updates); + if (DT) + DT->applyUpdates(Updates); + if (PDT) + PDT->applyUpdates(Updates); } // Update LoopInfo if it is around. diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp index 3466dedd3236..27f110e24f9c 100644 --- a/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/lib/Transforms/Utils/BuildLibCalls.cpp @@ -1,9 +1,8 @@ //===- BuildLibCalls.cpp - Utility builder for libcalls -------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
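A few hunks above, SplitBlockAndInsertIfThen gains an optional ThenBlock parameter: when the caller supplies an existing block, its terminator is reused and only its dominator entry is adjusted instead of creating a fresh block per check. A hedged usage sketch (Cond, InsertPt, and TrapBB are placeholders; DT/LI are deliberately left null here because a shared block need not be dominated by any single check):

    // Branch to a shared trap block instead of materialising a new one per check.
    Instruction *ThenTerm = SplitBlockAndInsertIfThen(
        Cond, /*SplitBefore=*/InsertPt, /*Unreachable=*/false,
        /*BranchWeights=*/nullptr, /*DT=*/nullptr, /*LI=*/nullptr,
        /*ThenBlock=*/TrapBB);
    // ThenTerm is TrapBB's existing terminator; the check's instructions
    // can be inserted in front of it.
    IRBuilder<> Builder(ThenTerm);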
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -23,6 +22,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" +#include "llvm/Analysis/MemoryBuiltins.h" using namespace llvm; @@ -121,6 +121,13 @@ static bool setNonLazyBind(Function &F) { return true; } +static bool setDoesNotFreeMemory(Function &F) { + if (F.hasFnAttribute(Attribute::NoFree)) + return false; + F.addFnAttr(Attribute::NoFree); + return true; +} + bool llvm::inferLibFuncAttributes(Module *M, StringRef Name, const TargetLibraryInfo &TLI) { Function *F = M->getFunction(Name); @@ -136,6 +143,9 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { bool Changed = false; + if(!isLibFreeFunction(&F, TheLibFunc) && !isReallocLikeFn(&F, &TLI)) + Changed |= setDoesNotFreeMemory(F); + if (F.getParent() != nullptr && F.getParent()->getRtLibUseGOT()) Changed |= setNonLazyBind(F); @@ -790,95 +800,76 @@ Value *llvm::castToCStr(Value *V, IRBuilder<> &B) { return B.CreateBitCast(V, B.getInt8PtrTy(AS), "cstr"); } -Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL, - const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_strlen)) +static Value *emitLibCall(LibFunc TheLibFunc, Type *ReturnType, + ArrayRef<Type *> ParamTypes, + ArrayRef<Value *> Operands, IRBuilder<> &B, + const TargetLibraryInfo *TLI, + bool IsVaArgs = false) { + if (!TLI->has(TheLibFunc)) return nullptr; Module *M = B.GetInsertBlock()->getModule(); - StringRef StrlenName = TLI->getName(LibFunc_strlen); - LLVMContext &Context = B.GetInsertBlock()->getContext(); - Constant *StrLen = M->getOrInsertFunction(StrlenName, DL.getIntPtrType(Context), - B.getInt8PtrTy()); - inferLibFuncAttributes(M, StrlenName, *TLI); - CallInst *CI = B.CreateCall(StrLen, castToCStr(Ptr, B), StrlenName); - if (const Function *F = dyn_cast<Function>(StrLen->stripPointerCasts())) + StringRef FuncName = TLI->getName(TheLibFunc); + FunctionType *FuncType = FunctionType::get(ReturnType, ParamTypes, IsVaArgs); + FunctionCallee Callee = M->getOrInsertFunction(FuncName, FuncType); + inferLibFuncAttributes(M, FuncName, *TLI); + CallInst *CI = B.CreateCall(Callee, Operands, FuncName); + if (const Function *F = + dyn_cast<Function>(Callee.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); - return CI; } -Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B, +Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_strchr)) - return nullptr; + LLVMContext &Context = B.GetInsertBlock()->getContext(); + return emitLibCall(LibFunc_strlen, DL.getIntPtrType(Context), + B.getInt8PtrTy(), castToCStr(Ptr, B), B, TLI); +} - Module *M = B.GetInsertBlock()->getModule(); - StringRef StrChrName = TLI->getName(LibFunc_strchr); +Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { Type *I8Ptr = B.getInt8PtrTy(); Type *I32Ty = B.getInt32Ty(); - Constant *StrChr = - M->getOrInsertFunction(StrChrName, I8Ptr, I8Ptr, I32Ty); - inferLibFuncAttributes(M, StrChrName, *TLI); - CallInst *CI = B.CreateCall( - StrChr, {castToCStr(Ptr, B), ConstantInt::get(I32Ty, C)}, StrChrName); - if (const Function *F = dyn_cast<Function>(StrChr->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); - return CI; + return emitLibCall(LibFunc_strchr, I8Ptr, {I8Ptr, I32Ty}, + {castToCStr(Ptr, B), 
ConstantInt::get(I32Ty, C)}, B, TLI); } Value *llvm::emitStrNCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_strncmp)) - return nullptr; - - Module *M = B.GetInsertBlock()->getModule(); - StringRef StrNCmpName = TLI->getName(LibFunc_strncmp); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *StrNCmp = M->getOrInsertFunction(StrNCmpName, B.getInt32Ty(), - B.getInt8PtrTy(), B.getInt8PtrTy(), - DL.getIntPtrType(Context)); - inferLibFuncAttributes(M, StrNCmpName, *TLI); - CallInst *CI = B.CreateCall( - StrNCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, StrNCmpName); - - if (const Function *F = dyn_cast<Function>(StrNCmp->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); - - return CI; + return emitLibCall( + LibFunc_strncmp, B.getInt32Ty(), + {B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context)}, + {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); } Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B, - const TargetLibraryInfo *TLI, StringRef Name) { - if (!TLI->has(LibFunc_strcpy)) - return nullptr; + const TargetLibraryInfo *TLI) { + Type *I8Ptr = B.getInt8PtrTy(); + return emitLibCall(LibFunc_strcpy, I8Ptr, {I8Ptr, I8Ptr}, + {castToCStr(Dst, B), castToCStr(Src, B)}, B, TLI); +} - Module *M = B.GetInsertBlock()->getModule(); +Value *llvm::emitStpCpy(Value *Dst, Value *Src, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { Type *I8Ptr = B.getInt8PtrTy(); - Value *StrCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr); - inferLibFuncAttributes(M, Name, *TLI); - CallInst *CI = - B.CreateCall(StrCpy, {castToCStr(Dst, B), castToCStr(Src, B)}, Name); - if (const Function *F = dyn_cast<Function>(StrCpy->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); - return CI; + return emitLibCall(LibFunc_stpcpy, I8Ptr, {I8Ptr, I8Ptr}, + {castToCStr(Dst, B), castToCStr(Src, B)}, B, TLI); } Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B, - const TargetLibraryInfo *TLI, StringRef Name) { - if (!TLI->has(LibFunc_strncpy)) - return nullptr; + const TargetLibraryInfo *TLI) { + Type *I8Ptr = B.getInt8PtrTy(); + return emitLibCall(LibFunc_strncpy, I8Ptr, {I8Ptr, I8Ptr, Len->getType()}, + {castToCStr(Dst, B), castToCStr(Src, B), Len}, B, TLI); +} - Module *M = B.GetInsertBlock()->getModule(); +Value *llvm::emitStpNCpy(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { Type *I8Ptr = B.getInt8PtrTy(); - Value *StrNCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr, - Len->getType()); - inferLibFuncAttributes(M, Name, *TLI); - CallInst *CI = B.CreateCall( - StrNCpy, {castToCStr(Dst, B), castToCStr(Src, B), Len}, Name); - if (const Function *F = dyn_cast<Function>(StrNCpy->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); - return CI; + return emitLibCall(LibFunc_stpncpy, I8Ptr, {I8Ptr, I8Ptr, Len->getType()}, + {castToCStr(Dst, B), castToCStr(Src, B), Len}, B, TLI); } Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, @@ -892,57 +883,115 @@ Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, AS = AttributeList::get(M->getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *MemCpy = M->getOrInsertFunction( + FunctionCallee MemCpy = M->getOrInsertFunction( "__memcpy_chk", AttributeList::get(M->getContext(), AS), B.getInt8PtrTy(), 
B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context), DL.getIntPtrType(Context)); Dst = castToCStr(Dst, B); Src = castToCStr(Src, B); CallInst *CI = B.CreateCall(MemCpy, {Dst, Src, Len, ObjSize}); - if (const Function *F = dyn_cast<Function>(MemCpy->stripPointerCasts())) + if (const Function *F = + dyn_cast<Function>(MemCpy.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; } Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_memchr)) - return nullptr; - - Module *M = B.GetInsertBlock()->getModule(); - StringRef MemChrName = TLI->getName(LibFunc_memchr); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *MemChr = M->getOrInsertFunction(MemChrName, B.getInt8PtrTy(), - B.getInt8PtrTy(), B.getInt32Ty(), - DL.getIntPtrType(Context)); - inferLibFuncAttributes(M, MemChrName, *TLI); - CallInst *CI = B.CreateCall(MemChr, {castToCStr(Ptr, B), Val, Len}, MemChrName); - - if (const Function *F = dyn_cast<Function>(MemChr->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); - - return CI; + return emitLibCall( + LibFunc_memchr, B.getInt8PtrTy(), + {B.getInt8PtrTy(), B.getInt32Ty(), DL.getIntPtrType(Context)}, + {castToCStr(Ptr, B), Val, Len}, B, TLI); } Value *llvm::emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_memcmp)) - return nullptr; + LLVMContext &Context = B.GetInsertBlock()->getContext(); + return emitLibCall( + LibFunc_memcmp, B.getInt32Ty(), + {B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context)}, + {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); +} - Module *M = B.GetInsertBlock()->getModule(); - StringRef MemCmpName = TLI->getName(LibFunc_memcmp); +Value *llvm::emitBCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, + const DataLayout &DL, const TargetLibraryInfo *TLI) { LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *MemCmp = M->getOrInsertFunction(MemCmpName, B.getInt32Ty(), - B.getInt8PtrTy(), B.getInt8PtrTy(), - DL.getIntPtrType(Context)); - inferLibFuncAttributes(M, MemCmpName, *TLI); - CallInst *CI = B.CreateCall( - MemCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, MemCmpName); - - if (const Function *F = dyn_cast<Function>(MemCmp->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); + return emitLibCall( + LibFunc_bcmp, B.getInt32Ty(), + {B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context)}, + {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, B, TLI); +} - return CI; +Value *llvm::emitMemCCpy(Value *Ptr1, Value *Ptr2, Value *Val, Value *Len, + IRBuilder<> &B, const TargetLibraryInfo *TLI) { + return emitLibCall( + LibFunc_memccpy, B.getInt8PtrTy(), + {B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt32Ty(), Len->getType()}, + {Ptr1, Ptr2, Val, Len}, B, TLI); +} + +Value *llvm::emitSNPrintf(Value *Dest, Value *Size, Value *Fmt, + ArrayRef<Value *> VariadicArgs, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + SmallVector<Value *, 8> Args{castToCStr(Dest, B), Size, castToCStr(Fmt, B)}; + Args.insert(Args.end(), VariadicArgs.begin(), VariadicArgs.end()); + return emitLibCall(LibFunc_snprintf, B.getInt32Ty(), + {B.getInt8PtrTy(), Size->getType(), B.getInt8PtrTy()}, + Args, B, TLI, /*IsVaArgs=*/true); +} + +Value *llvm::emitSPrintf(Value *Dest, Value *Fmt, + ArrayRef<Value *> VariadicArgs, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + 
SmallVector<Value *, 8> Args{castToCStr(Dest, B), castToCStr(Fmt, B)}; + Args.insert(Args.end(), VariadicArgs.begin(), VariadicArgs.end()); + return emitLibCall(LibFunc_sprintf, B.getInt32Ty(), + {B.getInt8PtrTy(), B.getInt8PtrTy()}, Args, B, TLI, + /*IsVaArgs=*/true); +} + +Value *llvm::emitStrCat(Value *Dest, Value *Src, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + return emitLibCall(LibFunc_strcat, B.getInt8PtrTy(), + {B.getInt8PtrTy(), B.getInt8PtrTy()}, + {castToCStr(Dest, B), castToCStr(Src, B)}, B, TLI); +} + +Value *llvm::emitStrLCpy(Value *Dest, Value *Src, Value *Size, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + return emitLibCall(LibFunc_strlcpy, Size->getType(), + {B.getInt8PtrTy(), B.getInt8PtrTy(), Size->getType()}, + {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI); +} + +Value *llvm::emitStrLCat(Value *Dest, Value *Src, Value *Size, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + return emitLibCall(LibFunc_strlcat, Size->getType(), + {B.getInt8PtrTy(), B.getInt8PtrTy(), Size->getType()}, + {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI); +} + +Value *llvm::emitStrNCat(Value *Dest, Value *Src, Value *Size, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + return emitLibCall(LibFunc_strncat, B.getInt8PtrTy(), + {B.getInt8PtrTy(), B.getInt8PtrTy(), Size->getType()}, + {castToCStr(Dest, B), castToCStr(Src, B), Size}, B, TLI); +} + +Value *llvm::emitVSNPrintf(Value *Dest, Value *Size, Value *Fmt, Value *VAList, + IRBuilder<> &B, const TargetLibraryInfo *TLI) { + return emitLibCall( + LibFunc_vsnprintf, B.getInt32Ty(), + {B.getInt8PtrTy(), Size->getType(), B.getInt8PtrTy(), VAList->getType()}, + {castToCStr(Dest, B), Size, castToCStr(Fmt, B), VAList}, B, TLI); +} + +Value *llvm::emitVSPrintf(Value *Dest, Value *Fmt, Value *VAList, + IRBuilder<> &B, const TargetLibraryInfo *TLI) { + return emitLibCall(LibFunc_vsprintf, B.getInt32Ty(), + {B.getInt8PtrTy(), B.getInt8PtrTy(), VAList->getType()}, + {castToCStr(Dest, B), castToCStr(Fmt, B), VAList}, B, TLI); } /// Append a suffix to the function name according to the type of 'Op'. 
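All of the string and memory emitters above are now thin wrappers around the file-local emitLibCall helper, which centralizes the FunctionCallee plumbing (getOrInsertFunction, attribute inference, calling-convention propagation). Purely as an illustration of the shape a future emitter would take in the same file, here is a hypothetical wrapper; strdup is an example and is not part of this change:

    // Hypothetical wrapper in the same TU as emitLibCall:
    //   char *strdup(const char *s)
    static Value *emitStrDup(Value *Ptr, IRBuilder<> &B,
                             const TargetLibraryInfo *TLI) {
      Type *I8Ptr = B.getInt8PtrTy();
      return emitLibCall(LibFunc_strdup, I8Ptr, {I8Ptr},
                         {castToCStr(Ptr, B)}, B, TLI);
    }

Callers keep the old contract: a null return still means the target library does not provide the routine.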
@@ -966,8 +1015,8 @@ static Value *emitUnaryFloatFnCallHelper(Value *Op, StringRef Name, assert((Name != "") && "Must specify Name to emitUnaryFloatFnCall"); Module *M = B.GetInsertBlock()->getModule(); - Value *Callee = M->getOrInsertFunction(Name, Op->getType(), - Op->getType()); + FunctionCallee Callee = + M->getOrInsertFunction(Name, Op->getType(), Op->getType()); CallInst *CI = B.CreateCall(Callee, Op, Name); // The incoming attribute set may have come from a speculatable intrinsic, but @@ -976,7 +1025,8 @@ static Value *emitUnaryFloatFnCallHelper(Value *Op, StringRef Name, CI->setAttributes(Attrs.removeAttribute(B.getContext(), AttributeList::FunctionIndex, Attribute::Speculatable)); - if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) + if (const Function *F = + dyn_cast<Function>(Callee.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; @@ -1009,11 +1059,12 @@ Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, appendTypeSuffix(Op1, Name, NameBuffer); Module *M = B.GetInsertBlock()->getModule(); - Value *Callee = M->getOrInsertFunction(Name, Op1->getType(), Op1->getType(), - Op2->getType()); + FunctionCallee Callee = M->getOrInsertFunction( + Name, Op1->getType(), Op1->getType(), Op2->getType()); CallInst *CI = B.CreateCall(Callee, {Op1, Op2}, Name); CI->setAttributes(Attrs); - if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) + if (const Function *F = + dyn_cast<Function>(Callee.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; @@ -1026,7 +1077,8 @@ Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getModule(); StringRef PutCharName = TLI->getName(LibFunc_putchar); - Value *PutChar = M->getOrInsertFunction(PutCharName, B.getInt32Ty(), B.getInt32Ty()); + FunctionCallee PutChar = + M->getOrInsertFunction(PutCharName, B.getInt32Ty(), B.getInt32Ty()); inferLibFuncAttributes(M, PutCharName, *TLI); CallInst *CI = B.CreateCall(PutChar, B.CreateIntCast(Char, @@ -1035,7 +1087,8 @@ Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B, "chari"), PutCharName); - if (const Function *F = dyn_cast<Function>(PutChar->stripPointerCasts())) + if (const Function *F = + dyn_cast<Function>(PutChar.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; } @@ -1047,11 +1100,12 @@ Value *llvm::emitPutS(Value *Str, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getModule(); StringRef PutsName = TLI->getName(LibFunc_puts); - Value *PutS = + FunctionCallee PutS = M->getOrInsertFunction(PutsName, B.getInt32Ty(), B.getInt8PtrTy()); inferLibFuncAttributes(M, PutsName, *TLI); CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), PutsName); - if (const Function *F = dyn_cast<Function>(PutS->stripPointerCasts())) + if (const Function *F = + dyn_cast<Function>(PutS.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; } @@ -1063,15 +1117,16 @@ Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getModule(); StringRef FPutcName = TLI->getName(LibFunc_fputc); - Constant *F = M->getOrInsertFunction(FPutcName, B.getInt32Ty(), B.getInt32Ty(), - File->getType()); + FunctionCallee F = M->getOrInsertFunction(FPutcName, B.getInt32Ty(), + B.getInt32Ty(), File->getType()); if (File->getType()->isPointerTy()) inferLibFuncAttributes(M, FPutcName, *TLI); Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, "chari"); CallInst *CI = 
B.CreateCall(F, {Char, File}, FPutcName); - if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + if (const Function *Fn = + dyn_cast<Function>(F.getCallee()->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); return CI; } @@ -1083,14 +1138,15 @@ Value *llvm::emitFPutCUnlocked(Value *Char, Value *File, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getModule(); StringRef FPutcUnlockedName = TLI->getName(LibFunc_fputc_unlocked); - Constant *F = M->getOrInsertFunction(FPutcUnlockedName, B.getInt32Ty(), - B.getInt32Ty(), File->getType()); + FunctionCallee F = M->getOrInsertFunction(FPutcUnlockedName, B.getInt32Ty(), + B.getInt32Ty(), File->getType()); if (File->getType()->isPointerTy()) inferLibFuncAttributes(M, FPutcUnlockedName, *TLI); Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/ true, "chari"); CallInst *CI = B.CreateCall(F, {Char, File}, FPutcUnlockedName); - if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + if (const Function *Fn = + dyn_cast<Function>(F.getCallee()->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); return CI; } @@ -1102,13 +1158,14 @@ Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getModule(); StringRef FPutsName = TLI->getName(LibFunc_fputs); - Constant *F = M->getOrInsertFunction( - FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), File->getType()); + FunctionCallee F = M->getOrInsertFunction(FPutsName, B.getInt32Ty(), + B.getInt8PtrTy(), File->getType()); if (File->getType()->isPointerTy()) inferLibFuncAttributes(M, FPutsName, *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, FPutsName); - if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + if (const Function *Fn = + dyn_cast<Function>(F.getCallee()->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); return CI; } @@ -1120,13 +1177,14 @@ Value *llvm::emitFPutSUnlocked(Value *Str, Value *File, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getModule(); StringRef FPutsUnlockedName = TLI->getName(LibFunc_fputs_unlocked); - Constant *F = M->getOrInsertFunction(FPutsUnlockedName, B.getInt32Ty(), - B.getInt8PtrTy(), File->getType()); + FunctionCallee F = M->getOrInsertFunction(FPutsUnlockedName, B.getInt32Ty(), + B.getInt8PtrTy(), File->getType()); if (File->getType()->isPointerTy()) inferLibFuncAttributes(M, FPutsUnlockedName, *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, FPutsUnlockedName); - if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + if (const Function *Fn = + dyn_cast<Function>(F.getCallee()->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); return CI; } @@ -1139,7 +1197,7 @@ Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); StringRef FWriteName = TLI->getName(LibFunc_fwrite); - Constant *F = M->getOrInsertFunction( + FunctionCallee F = M->getOrInsertFunction( FWriteName, DL.getIntPtrType(Context), B.getInt8PtrTy(), DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); @@ -1149,7 +1207,8 @@ Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B, B.CreateCall(F, {castToCStr(Ptr, B), Size, ConstantInt::get(DL.getIntPtrType(Context), 1), File}); - if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + if (const Function *Fn = + dyn_cast<Function>(F.getCallee()->stripPointerCasts())) 
CI->setCallingConv(Fn->getCallingConv()); return CI; } @@ -1162,12 +1221,13 @@ Value *llvm::emitMalloc(Value *Num, IRBuilder<> &B, const DataLayout &DL, Module *M = B.GetInsertBlock()->getModule(); StringRef MallocName = TLI->getName(LibFunc_malloc); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *Malloc = M->getOrInsertFunction(MallocName, B.getInt8PtrTy(), - DL.getIntPtrType(Context)); + FunctionCallee Malloc = M->getOrInsertFunction(MallocName, B.getInt8PtrTy(), + DL.getIntPtrType(Context)); inferLibFuncAttributes(M, MallocName, *TLI); CallInst *CI = B.CreateCall(Malloc, Num, MallocName); - if (const Function *F = dyn_cast<Function>(Malloc->stripPointerCasts())) + if (const Function *F = + dyn_cast<Function>(Malloc.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; @@ -1182,12 +1242,13 @@ Value *llvm::emitCalloc(Value *Num, Value *Size, const AttributeList &Attrs, StringRef CallocName = TLI.getName(LibFunc_calloc); const DataLayout &DL = M->getDataLayout(); IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext())); - Value *Calloc = M->getOrInsertFunction(CallocName, Attrs, B.getInt8PtrTy(), - PtrType, PtrType); + FunctionCallee Calloc = M->getOrInsertFunction( + CallocName, Attrs, B.getInt8PtrTy(), PtrType, PtrType); inferLibFuncAttributes(M, CallocName, TLI); CallInst *CI = B.CreateCall(Calloc, {Num, Size}, CallocName); - if (const auto *F = dyn_cast<Function>(Calloc->stripPointerCasts())) + if (const auto *F = + dyn_cast<Function>(Calloc.getCallee()->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; @@ -1202,7 +1263,7 @@ Value *llvm::emitFWriteUnlocked(Value *Ptr, Value *Size, Value *N, Value *File, Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); StringRef FWriteUnlockedName = TLI->getName(LibFunc_fwrite_unlocked); - Constant *F = M->getOrInsertFunction( + FunctionCallee F = M->getOrInsertFunction( FWriteUnlockedName, DL.getIntPtrType(Context), B.getInt8PtrTy(), DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); @@ -1210,7 +1271,8 @@ Value *llvm::emitFWriteUnlocked(Value *Ptr, Value *Size, Value *N, Value *File, inferLibFuncAttributes(M, FWriteUnlockedName, *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Ptr, B), Size, N, File}); - if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + if (const Function *Fn = + dyn_cast<Function>(F.getCallee()->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); return CI; } @@ -1222,13 +1284,14 @@ Value *llvm::emitFGetCUnlocked(Value *File, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getModule(); StringRef FGetCUnlockedName = TLI->getName(LibFunc_fgetc_unlocked); - Constant *F = - M->getOrInsertFunction(FGetCUnlockedName, B.getInt32Ty(), File->getType()); + FunctionCallee F = M->getOrInsertFunction(FGetCUnlockedName, B.getInt32Ty(), + File->getType()); if (File->getType()->isPointerTy()) inferLibFuncAttributes(M, FGetCUnlockedName, *TLI); CallInst *CI = B.CreateCall(F, File, FGetCUnlockedName); - if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + if (const Function *Fn = + dyn_cast<Function>(F.getCallee()->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); return CI; } @@ -1240,14 +1303,15 @@ Value *llvm::emitFGetSUnlocked(Value *Str, Value *Size, Value *File, Module *M = B.GetInsertBlock()->getModule(); StringRef FGetSUnlockedName = TLI->getName(LibFunc_fgets_unlocked); - Constant *F = + FunctionCallee F = 
M->getOrInsertFunction(FGetSUnlockedName, B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt32Ty(), File->getType()); inferLibFuncAttributes(M, FGetSUnlockedName, *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), Size, File}, FGetSUnlockedName); - if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + if (const Function *Fn = + dyn_cast<Function>(F.getCallee()->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); return CI; } @@ -1261,7 +1325,7 @@ Value *llvm::emitFReadUnlocked(Value *Ptr, Value *Size, Value *N, Value *File, Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); StringRef FReadUnlockedName = TLI->getName(LibFunc_fread_unlocked); - Constant *F = M->getOrInsertFunction( + FunctionCallee F = M->getOrInsertFunction( FReadUnlockedName, DL.getIntPtrType(Context), B.getInt8PtrTy(), DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); @@ -1269,7 +1333,8 @@ Value *llvm::emitFReadUnlocked(Value *Ptr, Value *Size, Value *N, Value *File, inferLibFuncAttributes(M, FReadUnlockedName, *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Ptr, B), Size, N, File}); - if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + if (const Function *Fn = + dyn_cast<Function>(F.getCallee()->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); return CI; } diff --git a/lib/Transforms/Utils/BypassSlowDivision.cpp b/lib/Transforms/Utils/BypassSlowDivision.cpp index e7828af648a9..df299f673f65 100644 --- a/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -1,9 +1,8 @@ //===- BypassSlowDivision.cpp - Bypass slow division ----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/CallPromotionUtils.cpp b/lib/Transforms/Utils/CallPromotionUtils.cpp index e58ddcf34667..f04d76e70c0d 100644 --- a/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -1,9 +1,8 @@ //===- CallPromotionUtils.cpp - Utilities for call promotion ----*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -367,8 +366,9 @@ Instruction *llvm::promoteCall(CallSite CS, Function *Callee, CastInst **RetBitCast) { assert(!CS.getCalledFunction() && "Only indirect call sites can be promoted"); - // Set the called function of the call site to be the given callee. - CS.setCalledFunction(Callee); + // Set the called function of the call site to be the given callee (but don't + // change the type). + cast<CallBase>(CS.getInstruction())->setCalledOperand(Callee); // Since the call site will no longer be direct, we must clear metadata that // is only appropriate for indirect calls. 
This includes !prof and !callees @@ -412,6 +412,15 @@ Instruction *llvm::promoteCall(CallSite CS, Function *Callee, // Remove any incompatible attributes for the argument. AttrBuilder ArgAttrs(CallerPAL.getParamAttributes(ArgNo)); ArgAttrs.remove(AttributeFuncs::typeIncompatible(FormalTy)); + + // If byval is used, this must be a pointer type, and the byval type must + // match the element type. Update it if present. + if (ArgAttrs.getByValType()) { + Type *NewTy = Callee->getParamByValType(ArgNo); + ArgAttrs.addByValAttr( + NewTy ? NewTy : cast<PointerType>(FormalTy)->getElementType()); + } + NewArgAttrs.push_back(AttributeSet::get(Ctx, ArgAttrs)); AttributeChanged = true; } else diff --git a/lib/Transforms/Utils/CanonicalizeAliases.cpp b/lib/Transforms/Utils/CanonicalizeAliases.cpp index cf41fd2e14c0..455fcbb1cf98 100644 --- a/lib/Transforms/Utils/CanonicalizeAliases.cpp +++ b/lib/Transforms/Utils/CanonicalizeAliases.cpp @@ -1,9 +1,8 @@ //===- CanonicalizeAliases.cpp - ThinLTO Support: Canonicalize Aliases ----===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp index 8f8c601f5f13..1026c9d37038 100644 --- a/lib/Transforms/Utils/CloneFunction.cpp +++ b/lib/Transforms/Utils/CloneFunction.cpp @@ -1,9 +1,8 @@ //===- CloneFunction.cpp - Clone a function into another function ---------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -16,13 +15,13 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instructions.h" @@ -740,12 +739,12 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB, const Twine &NameSuffix, LoopInfo *LI, DominatorTree *DT, SmallVectorImpl<BasicBlock *> &Blocks) { - assert(OrigLoop->getSubLoops().empty() && - "Loop to be cloned cannot have inner loop"); Function *F = OrigLoop->getHeader()->getParent(); Loop *ParentLoop = OrigLoop->getParentLoop(); + DenseMap<Loop *, Loop *> LMap; Loop *NewLoop = LI->AllocateLoop(); + LMap[OrigLoop] = NewLoop; if (ParentLoop) ParentLoop->addChildLoop(NewLoop); else @@ -765,14 +764,36 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB, // Update DominatorTree. 
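The CallPromotionUtils hunk above keeps the indirect call's function type intact (setCalledOperand rather than setCalledFunction) and re-derives attributes such as the byval element type from the direct callee. For context, a sketch of the usual driver-side pattern using the existing isLegalToPromote / promoteCallWithIfThenElse helpers from the same header (DirectCallee and Weights are placeholders):

    // Guarded promotion of an indirect call site CS to a known target.
    const char *Reason = nullptr;
    if (isLegalToPromote(CS, DirectCallee, &Reason))
      promoteCallWithIfThenElse(CS, DirectCallee, /*BranchWeights=*/Weights);
    // else: Reason explains why the site must stay indirect.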
DT->addNewBlock(NewPH, LoopDomBB); + for (Loop *CurLoop : OrigLoop->getLoopsInPreorder()) { + Loop *&NewLoop = LMap[CurLoop]; + if (!NewLoop) { + NewLoop = LI->AllocateLoop(); + + // Establish the parent/child relationship. + Loop *OrigParent = CurLoop->getParentLoop(); + assert(OrigParent && "Could not find the original parent loop"); + Loop *NewParentLoop = LMap[OrigParent]; + assert(NewParentLoop && "Could not find the new parent loop"); + + NewParentLoop->addChildLoop(NewLoop); + } + } + for (BasicBlock *BB : OrigLoop->getBlocks()) { + Loop *CurLoop = LI->getLoopFor(BB); + Loop *&NewLoop = LMap[CurLoop]; + assert(NewLoop && "Expecting new loop to be allocated"); + BasicBlock *NewBB = CloneBasicBlock(BB, VMap, NameSuffix, F); VMap[BB] = NewBB; // Update LoopInfo. NewLoop->addBasicBlockToLoop(NewBB, *LI); + if (BB == CurLoop->getHeader()) + NewLoop->moveToHeader(NewBB); - // Add DominatorTree node. After seeing all blocks, update to correct IDom. + // Add DominatorTree node. After seeing all blocks, update to correct + // IDom. DT->addNewBlock(NewBB, NewPH); Blocks.push_back(NewBB); diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp index 659993aa5478..7ddf59becba9 100644 --- a/lib/Transforms/Utils/CloneModule.cpp +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -1,9 +1,8 @@ //===- CloneModule.cpp - Clone an entire module ---------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index 25d4ae583ecc..fa6d3f8ae873 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -1,9 +1,8 @@ //===- CodeExtractor.cpp - Pull code region into a new function -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
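With the CloneFunction change above, cloneLoopWithPreheader no longer rejects loops that have children: it allocates a Loop for every member of the nest via LMap and re-establishes parents and headers. A hedged caller sketch (InsertBefore and LoopDomBB are placeholders; the remap step is the usual companion call from Cloning.h):

    SmallVector<BasicBlock *, 8> ClonedBlocks;
    ValueToValueMapTy VMap;
    // OrigLoop may now be a whole loop nest, not just an innermost loop.
    Loop *Cloned = cloneLoopWithPreheader(InsertBefore, LoopDomBB, OrigLoop,
                                          VMap, ".cloned", LI, DT, ClonedBlocks);
    // Rewrite operands inside the clone to refer to the cloned values;
    // Cloned is the new outermost loop of the copied nest.
    remapInstructionsInBlocks(ClonedBlocks, VMap);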
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -21,6 +20,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" #include "llvm/Analysis/BranchProbabilityInfo.h" @@ -44,6 +44,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" @@ -67,6 +68,7 @@ #include <vector> using namespace llvm; +using namespace llvm::PatternMatch; using ProfileCount = Function::ProfileCount; #define DEBUG_TYPE "code-extractor" @@ -207,6 +209,9 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, llvm_unreachable("Repeated basic blocks in extraction input"); } + LLVM_DEBUG(dbgs() << "Region front block: " << Result.front()->getName() + << '\n'); + for (auto *BB : Result) { if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca)) return {}; @@ -224,9 +229,11 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, // the subgraph which is being extracted. for (auto *PBB : predecessors(BB)) if (!Result.count(PBB)) { - LLVM_DEBUG( - dbgs() << "No blocks in this region may have entries from " - "outside the region except for the first block!\n"); + LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from " + "outside the region except for the first block!\n" + << "Problematic source BB: " << BB->getName() << "\n" + << "Problematic destination BB: " << PBB->getName() + << "\n"); return {}; } } @@ -236,18 +243,20 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI, bool AllowVarArgs, - bool AllowAlloca, std::string Suffix) + BranchProbabilityInfo *BPI, AssumptionCache *AC, + bool AllowVarArgs, bool AllowAlloca, + std::string Suffix) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), AllowVarArgs(AllowVarArgs), + BPI(BPI), AC(AC), AllowVarArgs(AllowVarArgs), Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)), Suffix(Suffix) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI, std::string Suffix) + BranchProbabilityInfo *BPI, AssumptionCache *AC, + std::string Suffix) : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), AllowVarArgs(false), + BPI(BPI), AC(AC), AllowVarArgs(false), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT, /* AllowVarArgs */ false, /* AllowAlloca */ false)), @@ -325,7 +334,7 @@ bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers( if (dyn_cast<Constant>(MemAddr)) break; Value *Base = MemAddr->stripInBoundsConstantOffsets(); - if (!dyn_cast<AllocaInst>(Base) || Base == AI) + if (!isa<AllocaInst>(Base) || Base == AI) return false; break; } @@ -401,11 +410,74 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) { return CommonExitBlock; } +// Find the pair of life time markers for address 'Addr' that are either +// defined inside the outline region or can legally be shrinkwrapped into the +// outline region. 
If there are not other untracked uses of the address, return +// the pair of markers if found; otherwise return a pair of nullptr. +CodeExtractor::LifetimeMarkerInfo +CodeExtractor::getLifetimeMarkers(Instruction *Addr, + BasicBlock *ExitBlock) const { + LifetimeMarkerInfo Info; + + for (User *U : Addr->users()) { + IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U); + if (IntrInst) { + if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) { + // Do not handle the case where Addr has multiple start markers. + if (Info.LifeStart) + return {}; + Info.LifeStart = IntrInst; + } + if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) { + if (Info.LifeEnd) + return {}; + Info.LifeEnd = IntrInst; + } + continue; + } + // Find untracked uses of the address, bail. + if (!definedInRegion(Blocks, U)) + return {}; + } + + if (!Info.LifeStart || !Info.LifeEnd) + return {}; + + Info.SinkLifeStart = !definedInRegion(Blocks, Info.LifeStart); + Info.HoistLifeEnd = !definedInRegion(Blocks, Info.LifeEnd); + // Do legality check. + if ((Info.SinkLifeStart || Info.HoistLifeEnd) && + !isLegalToShrinkwrapLifetimeMarkers(Addr)) + return {}; + + // Check to see if we have a place to do hoisting, if not, bail. + if (Info.HoistLifeEnd && !ExitBlock) + return {}; + + return Info; +} + void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands, BasicBlock *&ExitBlock) const { Function *Func = (*Blocks.begin())->getParent(); ExitBlock = getCommonExitBlock(Blocks); + auto moveOrIgnoreLifetimeMarkers = + [&](const LifetimeMarkerInfo &LMI) -> bool { + if (!LMI.LifeStart) + return false; + if (LMI.SinkLifeStart) { + LLVM_DEBUG(dbgs() << "Sinking lifetime.start: " << *LMI.LifeStart + << "\n"); + SinkCands.insert(LMI.LifeStart); + } + if (LMI.HoistLifeEnd) { + LLVM_DEBUG(dbgs() << "Hoisting lifetime.end: " << *LMI.LifeEnd << "\n"); + HoistCands.insert(LMI.LifeEnd); + } + return true; + }; + for (BasicBlock &BB : *Func) { if (Blocks.count(&BB)) continue; @@ -414,95 +486,52 @@ void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands, if (!AI) continue; - // Find the pair of life time markers for address 'Addr' that are either - // defined inside the outline region or can legally be shrinkwrapped into - // the outline region. If there are not other untracked uses of the - // address, return the pair of markers if found; otherwise return a pair - // of nullptr. - auto GetLifeTimeMarkers = - [&](Instruction *Addr, bool &SinkLifeStart, - bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> { - Instruction *LifeStart = nullptr, *LifeEnd = nullptr; - - for (User *U : Addr->users()) { - IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U); - if (IntrInst) { - if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) { - // Do not handle the case where AI has multiple start markers. - if (LifeStart) - return std::make_pair<Instruction *>(nullptr, nullptr); - LifeStart = IntrInst; - } - if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) { - if (LifeEnd) - return std::make_pair<Instruction *>(nullptr, nullptr); - LifeEnd = IntrInst; - } - continue; - } - // Find untracked uses of the address, bail. - if (!definedInRegion(Blocks, U)) - return std::make_pair<Instruction *>(nullptr, nullptr); - } - - if (!LifeStart || !LifeEnd) - return std::make_pair<Instruction *>(nullptr, nullptr); - - SinkLifeStart = !definedInRegion(Blocks, LifeStart); - HoistLifeEnd = !definedInRegion(Blocks, LifeEnd); - // Do legality Check. 
- if ((SinkLifeStart || HoistLifeEnd) && - !isLegalToShrinkwrapLifetimeMarkers(Addr)) - return std::make_pair<Instruction *>(nullptr, nullptr); - - // Check to see if we have a place to do hoisting, if not, bail. - if (HoistLifeEnd && !ExitBlock) - return std::make_pair<Instruction *>(nullptr, nullptr); - - return std::make_pair(LifeStart, LifeEnd); - }; - - bool SinkLifeStart = false, HoistLifeEnd = false; - auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd); - - if (Markers.first) { - if (SinkLifeStart) - SinkCands.insert(Markers.first); + LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(AI, ExitBlock); + bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo); + if (Moved) { + LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n"); SinkCands.insert(AI); - if (HoistLifeEnd) - HoistCands.insert(Markers.second); continue; } - // Follow the bitcast. - Instruction *MarkerAddr = nullptr; + // Follow any bitcasts. + SmallVector<Instruction *, 2> Bitcasts; + SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo; for (User *U : AI->users()) { if (U->stripInBoundsConstantOffsets() == AI) { - SinkLifeStart = false; - HoistLifeEnd = false; Instruction *Bitcast = cast<Instruction>(U); - Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd); - if (Markers.first) { - MarkerAddr = Bitcast; + LifetimeMarkerInfo LMI = getLifetimeMarkers(Bitcast, ExitBlock); + if (LMI.LifeStart) { + Bitcasts.push_back(Bitcast); + BitcastLifetimeInfo.push_back(LMI); continue; } } // Found unknown use of AI. if (!definedInRegion(Blocks, U)) { - MarkerAddr = nullptr; + Bitcasts.clear(); break; } } - if (MarkerAddr) { - if (SinkLifeStart) - SinkCands.insert(Markers.first); - if (!definedInRegion(Blocks, MarkerAddr)) - SinkCands.insert(MarkerAddr); - SinkCands.insert(AI); - if (HoistLifeEnd) - HoistCands.insert(Markers.second); + // Either no bitcasts reference the alloca or there are unknown uses. + if (Bitcasts.empty()) + continue; + + LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n"); + SinkCands.insert(AI); + for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) { + Instruction *BitcastAddr = Bitcasts[I]; + const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I]; + assert(LMI.LifeStart && + "Unsafe to sink bitcast without lifetime markers"); + moveOrIgnoreLifetimeMarkers(LMI); + if (!definedInRegion(Blocks, BitcastAddr)) { + LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr + << "\n"); + SinkCands.insert(BitcastAddr); + } } } } @@ -780,6 +809,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::NoBuiltin: case Attribute::NoCapture: case Attribute::NoReturn: + case Attribute::NoSync: case Attribute::None: case Attribute::NonNull: case Attribute::ReadNone: @@ -792,8 +822,10 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::StructRet: case Attribute::SwiftError: case Attribute::SwiftSelf: + case Attribute::WillReturn: case Attribute::WriteOnly: case Attribute::ZExt: + case Attribute::ImmArg: case Attribute::EndAttrKinds: continue; // Those attributes should be safe to propagate to the extracted function. 
@@ -803,6 +835,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::InlineHint: case Attribute::MinSize: case Attribute::NoDuplicate: + case Attribute::NoFree: case Attribute::NoImplicitFloat: case Attribute::NoInline: case Attribute::NonLazyBind: @@ -817,6 +850,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::SanitizeMemory: case Attribute::SanitizeThread: case Attribute::SanitizeHWAddress: + case Attribute::SanitizeMemTag: case Attribute::SpeculativeLoadHardening: case Attribute::StackProtect: case Attribute::StackProtectReq: @@ -845,7 +879,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, Instruction *TI = newFunction->begin()->getTerminator(); GetElementPtrInst *GEP = GetElementPtrInst::Create( StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI); - RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI); + RewriteVal = new LoadInst(StructTy->getElementType(i), GEP, + "loadgep_" + inputs[i]->getName(), TI); } else RewriteVal = &*AI++; @@ -880,6 +915,88 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, return newFunction; } +/// Erase lifetime.start markers which reference inputs to the extraction +/// region, and insert the referenced memory into \p LifetimesStart. +/// +/// The extraction region is defined by a set of blocks (\p Blocks), and a set +/// of allocas which will be moved from the caller function into the extracted +/// function (\p SunkAllocas). +static void eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks, + const SetVector<Value *> &SunkAllocas, + SetVector<Value *> &LifetimesStart) { + for (BasicBlock *BB : Blocks) { + for (auto It = BB->begin(), End = BB->end(); It != End;) { + auto *II = dyn_cast<IntrinsicInst>(&*It); + ++It; + if (!II || !II->isLifetimeStartOrEnd()) + continue; + + // Get the memory operand of the lifetime marker. If the underlying + // object is a sunk alloca, or is otherwise defined in the extraction + // region, the lifetime marker must not be erased. + Value *Mem = II->getOperand(1)->stripInBoundsOffsets(); + if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem)) + continue; + + if (II->getIntrinsicID() == Intrinsic::lifetime_start) + LifetimesStart.insert(Mem); + II->eraseFromParent(); + } + } +} + +/// Insert lifetime start/end markers surrounding the call to the new function +/// for objects defined in the caller. +static void insertLifetimeMarkersSurroundingCall( + Module *M, ArrayRef<Value *> LifetimesStart, ArrayRef<Value *> LifetimesEnd, + CallInst *TheCall) { + LLVMContext &Ctx = M->getContext(); + auto Int8PtrTy = Type::getInt8PtrTy(Ctx); + auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1); + Instruction *Term = TheCall->getParent()->getTerminator(); + + // The memory argument to a lifetime marker must be a i8*. Cache any bitcasts + // needed to satisfy this requirement so they may be reused. + DenseMap<Value *, Value *> Bitcasts; + + // Emit lifetime markers for the pointers given in \p Objects. Insert the + // markers before the call if \p InsertBefore, and after the call otherwise. 
+ auto insertMarkers = [&](Function *MarkerFunc, ArrayRef<Value *> Objects, + bool InsertBefore) { + for (Value *Mem : Objects) { + assert((!isa<Instruction>(Mem) || cast<Instruction>(Mem)->getFunction() == + TheCall->getFunction()) && + "Input memory not defined in original function"); + Value *&MemAsI8Ptr = Bitcasts[Mem]; + if (!MemAsI8Ptr) { + if (Mem->getType() == Int8PtrTy) + MemAsI8Ptr = Mem; + else + MemAsI8Ptr = + CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall); + } + + auto Marker = CallInst::Create(MarkerFunc, {NegativeOne, MemAsI8Ptr}); + if (InsertBefore) + Marker->insertBefore(TheCall); + else + Marker->insertBefore(Term); + } + }; + + if (!LifetimesStart.empty()) { + auto StartFn = llvm::Intrinsic::getDeclaration( + M, llvm::Intrinsic::lifetime_start, Int8PtrTy); + insertMarkers(StartFn, LifetimesStart, /*InsertBefore=*/true); + } + + if (!LifetimesEnd.empty()) { + auto EndFn = llvm::Intrinsic::getDeclaration( + M, llvm::Intrinsic::lifetime_end, Int8PtrTy); + insertMarkers(EndFn, LifetimesEnd, /*InsertBefore=*/false); + } +} + /// emitCallAndSwitchStatement - This method sets up the caller side by adding /// the call instruction, splitting any PHI nodes in the header block as /// necessary. @@ -897,11 +1014,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, CallInst *call = nullptr; // Add inputs as params, or to be filled into the struct - for (Value *input : inputs) + unsigned ArgNo = 0; + SmallVector<unsigned, 1> SwiftErrorArgs; + for (Value *input : inputs) { if (AggregateArgs) StructValues.push_back(input); - else + else { params.push_back(input); + if (input->isSwiftError()) + SwiftErrorArgs.push_back(ArgNo); + } + ++ArgNo; + } // Create allocas for the outputs for (Value *output : outputs) { @@ -957,13 +1081,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } codeReplacer->getInstList().push_back(call); + // Set swifterror parameter attributes. + for (unsigned SwiftErrArgNo : SwiftErrorArgs) { + call->addParamAttr(SwiftErrArgNo, Attribute::SwiftError); + newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError); + } + Function::arg_iterator OutputArgBegin = newFunction->arg_begin(); unsigned FirstOut = inputs.size(); if (!AggregateArgs) std::advance(OutputArgBegin, inputs.size()); // Reload the outputs passed in by reference. - Function::arg_iterator OAI = OutputArgBegin; for (unsigned i = 0, e = outputs.size(); i != e; ++i) { Value *Output = nullptr; if (AggregateArgs) { @@ -977,7 +1106,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } else { Output = ReloadOutputs[i]; } - LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload"); + LoadInst *load = new LoadInst(outputs[i]->getType(), Output, + outputs[i]->getName() + ".reload"); Reloads.push_back(load); codeReplacer->getInstList().push_back(load); std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end()); @@ -986,40 +1116,6 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, if (!Blocks.count(inst->getParent())) inst->replaceUsesOfWith(outputs[i], load); } - - // Store to argument right after the definition of output value. - auto *OutI = dyn_cast<Instruction>(outputs[i]); - if (!OutI) - continue; - - // Find proper insertion point. - BasicBlock::iterator InsertPt; - // In case OutI is an invoke, we insert the store at the beginning in the - // 'normal destination' BB. Otherwise we insert the store right after OutI. 
- if (auto *InvokeI = dyn_cast<InvokeInst>(OutI)) - InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt(); - else if (auto *Phi = dyn_cast<PHINode>(OutI)) - InsertPt = Phi->getParent()->getFirstInsertionPt(); - else - InsertPt = std::next(OutI->getIterator()); - - assert(OAI != newFunction->arg_end() && - "Number of output arguments should match " - "the amount of defined values"); - if (AggregateArgs) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), &*InsertPt); - new StoreInst(outputs[i], GEP, &*InsertPt); - // Since there should be only one struct argument aggregating - // all the output values, we shouldn't increment OAI, which always - // points to the struct argument, in this case. - } else { - new StoreInst(outputs[i], &*OAI, &*InsertPt); - ++OAI; - } } // Now we can emit a switch statement using the call as a value. @@ -1075,6 +1171,50 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } } + // Store the arguments right after the definition of output value. + // This should be done after creating the exit stubs, to ensure that the invoke + // result restore will be placed in the outlined function. + Function::arg_iterator OAI = OutputArgBegin; + for (unsigned i = 0, e = outputs.size(); i != e; ++i) { + auto *OutI = dyn_cast<Instruction>(outputs[i]); + if (!OutI) + continue; + + // Find proper insertion point. + BasicBlock::iterator InsertPt; + // In case OutI is an invoke, we insert the store at the beginning in the + // 'normal destination' BB. Otherwise we insert the store right after OutI. + if (auto *InvokeI = dyn_cast<InvokeInst>(OutI)) + InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt(); + else if (auto *Phi = dyn_cast<PHINode>(OutI)) + InsertPt = Phi->getParent()->getFirstInsertionPt(); + else + InsertPt = std::next(OutI->getIterator()); + + Instruction *InsertBefore = &*InsertPt; + assert((InsertBefore->getFunction() == newFunction || + Blocks.count(InsertBefore->getParent())) && + "InsertPt should be in new function"); + assert(OAI != newFunction->arg_end() && + "Number of output arguments should match " + "the amount of defined values"); + if (AggregateArgs) { + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), + InsertBefore); + new StoreInst(outputs[i], GEP, InsertBefore); + // Since there should be only one struct argument aggregating + // all the output values, we shouldn't increment OAI, which always + // points to the struct argument, in this case. + } else { + new StoreInst(outputs[i], &*OAI, InsertBefore); + ++OAI; + } + } + // Now that we've done the deed, simplify the switch instruction. Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); switch (NumExitBlocks) { @@ -1119,6 +1259,10 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, break; } + // Insert lifetime markers around the reloads of any output values. The + // allocas that output values are stored in are only in use in the codeRepl block.
+ insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call); + return call; } @@ -1133,6 +1277,13 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) { // Insert this basic block into the new function newBlocks.push_back(Block); + + // Remove @llvm.assume calls that were moved to the new function from the + // old function's assumption cache. + if (AC) + for (auto &I : *Block) + if (match(&I, m_Intrinsic<Intrinsic::assume>())) + AC->unregisterAssumption(cast<CallInst>(&I)); } } @@ -1181,71 +1332,6 @@ void CodeExtractor::calculateNewCallTerminatorWeights( MDBuilder(TI->getContext()).createBranchWeights(BranchWeights)); } -/// Scan the extraction region for lifetime markers which reference inputs. -/// Erase these markers. Return the inputs which were referenced. -/// -/// The extraction region is defined by a set of blocks (\p Blocks), and a set -/// of allocas which will be moved from the caller function into the extracted -/// function (\p SunkAllocas). -static SetVector<Value *> -eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks, - const SetVector<Value *> &SunkAllocas) { - SetVector<Value *> InputObjectsWithLifetime; - for (BasicBlock *BB : Blocks) { - for (auto It = BB->begin(), End = BB->end(); It != End;) { - auto *II = dyn_cast<IntrinsicInst>(&*It); - ++It; - if (!II || !II->isLifetimeStartOrEnd()) - continue; - - // Get the memory operand of the lifetime marker. If the underlying - // object is a sunk alloca, or is otherwise defined in the extraction - // region, the lifetime marker must not be erased. - Value *Mem = II->getOperand(1)->stripInBoundsOffsets(); - if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem)) - continue; - - InputObjectsWithLifetime.insert(Mem); - II->eraseFromParent(); - } - } - return InputObjectsWithLifetime; -} - -/// Insert lifetime start/end markers surrounding the call to the new function -/// for objects defined in the caller. -static void insertLifetimeMarkersSurroundingCall( - Module *M, const SetVector<Value *> &InputObjectsWithLifetime, - CallInst *TheCall) { - if (InputObjectsWithLifetime.empty()) - return; - - LLVMContext &Ctx = M->getContext(); - auto Int8PtrTy = Type::getInt8PtrTy(Ctx); - auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1); - auto LifetimeStartFn = llvm::Intrinsic::getDeclaration( - M, llvm::Intrinsic::lifetime_start, Int8PtrTy); - auto LifetimeEndFn = llvm::Intrinsic::getDeclaration( - M, llvm::Intrinsic::lifetime_end, Int8PtrTy); - for (Value *Mem : InputObjectsWithLifetime) { - assert((!isa<Instruction>(Mem) || - cast<Instruction>(Mem)->getFunction() == TheCall->getFunction()) && - "Input memory not defined in original function"); - Value *MemAsI8Ptr = nullptr; - if (Mem->getType() == Int8PtrTy) - MemAsI8Ptr = Mem; - else - MemAsI8Ptr = - CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall); - - auto StartMarker = - CallInst::Create(LifetimeStartFn, {NegativeOne, MemAsI8Ptr}); - StartMarker->insertBefore(TheCall); - auto EndMarker = CallInst::Create(LifetimeEndFn, {NegativeOne, MemAsI8Ptr}); - EndMarker->insertAfter(TheCall); - } -} - Function *CodeExtractor::extractCodeRegion() { if (!isEligible()) return nullptr; @@ -1348,10 +1434,24 @@ Function *CodeExtractor::extractCodeRegion() { // Find inputs to, outputs from the code region. 
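// --- [Editorial sketch, not part of the patch] ---------------------------
// Context for the extractCodeRegion() changes that follow: a caller builds a
// CodeExtractor over the region's blocks and asks it to outline them; the
// lifetime-marker bookkeeping above is then replayed around the new call.
// Hedged usage sketch; the helper name and the decision to pass no optional
// analyses (DT/BFI/BPI/AC) are assumptions for brevity.
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
using namespace llvm;

static Function *outlineRegion(const SmallVectorImpl<BasicBlock *> &Region) {
  CodeExtractor CE(Region);
  if (!CE.isEligible())
    return nullptr;
  // Inputs become parameters, outputs are passed back through memory, and the
  // original blocks are replaced by a call to the new function.
  return CE.extractCodeRegion();
}
// --- [End editorial sketch] -----------------------------------------------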
findInputsOutputs(inputs, outputs, SinkingCands); - // Now sink all instructions which only have non-phi uses inside the region - for (auto *II : SinkingCands) - cast<Instruction>(II)->moveBefore(*newFuncRoot, - newFuncRoot->getFirstInsertionPt()); + // Now sink all instructions which only have non-phi uses inside the region. + // Group the allocas at the start of the block, so that any bitcast uses of + // the allocas are well-defined. + AllocaInst *FirstSunkAlloca = nullptr; + for (auto *II : SinkingCands) { + if (auto *AI = dyn_cast<AllocaInst>(II)) { + AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt()); + if (!FirstSunkAlloca) + FirstSunkAlloca = AI; + } + } + assert((SinkingCands.empty() || FirstSunkAlloca) && + "Did not expect a sink candidate without any allocas"); + for (auto *II : SinkingCands) { + if (!isa<AllocaInst>(II)) { + cast<Instruction>(II)->moveAfter(FirstSunkAlloca); + } + } if (!HoistingCands.empty()) { auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit); @@ -1361,11 +1461,11 @@ Function *CodeExtractor::extractCodeRegion() { } // Collect objects which are inputs to the extraction region and also - // referenced by lifetime start/end markers within it. The effects of these + // referenced by lifetime start markers within it. The effects of these // markers must be replicated in the calling function to prevent the stack // coloring pass from merging slots which store input objects. - ValueSet InputObjectsWithLifetime = - eraseLifetimeMarkersOnInputs(Blocks, SinkingCands); + ValueSet LifetimesStart; + eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart); // Construct new function based on inputs/outputs & add allocas for all defs. Function *newFunction = @@ -1388,8 +1488,8 @@ Function *CodeExtractor::extractCodeRegion() { // Replicate the effects of any lifetime start/end markers which referenced // input objects in the extraction region by placing markers around the call. - insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), - InputObjectsWithLifetime, TheCall); + insertLifetimeMarkersSurroundingCall( + oldFunction->getParent(), LifetimesStart.getArrayRef(), {}, TheCall); // Propagate personality info to the new function if there is one. if (oldFunction->hasPersonalityFn()) diff --git a/lib/Transforms/Utils/CtorUtils.cpp b/lib/Transforms/Utils/CtorUtils.cpp index 4e7da7d0449f..069a86f6ab33 100644 --- a/lib/Transforms/Utils/CtorUtils.cpp +++ b/lib/Transforms/Utils/CtorUtils.cpp @@ -1,9 +1,8 @@ //===- CtorUtils.cpp - Helpers for working with global_ctors ----*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/DemoteRegToStack.cpp b/lib/Transforms/Utils/DemoteRegToStack.cpp index 975b363859a9..5f53d794fe8a 100644 --- a/lib/Transforms/Utils/DemoteRegToStack.cpp +++ b/lib/Transforms/Utils/DemoteRegToStack.cpp @@ -1,9 +1,8 @@ //===- DemoteRegToStack.cpp - Move a virtual register to the stack --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. 
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -73,7 +72,8 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, Value *&V = Loads[PN->getIncomingBlock(i)]; if (!V) { // Insert the load into the predecessor block - V = new LoadInst(Slot, I.getName()+".reload", VolatileLoads, + V = new LoadInst(I.getType(), Slot, I.getName() + ".reload", + VolatileLoads, PN->getIncomingBlock(i)->getTerminator()); } PN->setIncomingValue(i, V); @@ -81,7 +81,8 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, } else { // If this is a normal instruction, just insert a load. - Value *V = new LoadInst(Slot, I.getName()+".reload", VolatileLoads, U); + Value *V = new LoadInst(I.getType(), Slot, I.getName() + ".reload", + VolatileLoads, U); U->replaceUsesOfWith(&I, V); } } @@ -142,7 +143,8 @@ AllocaInst *llvm::DemotePHIToStack(PHINode *P, Instruction *AllocaPoint) { for (; isa<PHINode>(InsertPt) || InsertPt->isEHPad(); ++InsertPt) /* empty */; // Don't insert before PHI nodes or landingpad instrs. - Value *V = new LoadInst(Slot, P->getName() + ".reload", &*InsertPt); + Value *V = + new LoadInst(P->getType(), Slot, P->getName() + ".reload", &*InsertPt); P->replaceAllUsesWith(V); // Delete PHI. diff --git a/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/lib/Transforms/Utils/EntryExitInstrumenter.cpp index 569ea58a3047..4aa40eeadda4 100644 --- a/lib/Transforms/Utils/EntryExitInstrumenter.cpp +++ b/lib/Transforms/Utils/EntryExitInstrumenter.cpp @@ -1,9 +1,8 @@ //===- EntryExitInstrumenter.cpp - Function Entry/Exit Instrumentation ----===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -31,7 +30,7 @@ static void insertCall(Function &CurFn, StringRef Func, Func == "__mcount" || Func == "_mcount" || Func == "__cyg_profile_func_enter_bare") { - Constant *Fn = M.getOrInsertFunction(Func, Type::getVoidTy(C)); + FunctionCallee Fn = M.getOrInsertFunction(Func, Type::getVoidTy(C)); CallInst *Call = CallInst::Create(Fn, "", InsertionPt); Call->setDebugLoc(DL); return; @@ -40,7 +39,7 @@ static void insertCall(Function &CurFn, StringRef Func, if (Func == "__cyg_profile_func_enter" || Func == "__cyg_profile_func_exit") { Type *ArgTypes[] = {Type::getInt8PtrTy(C), Type::getInt8PtrTy(C)}; - Constant *Fn = M.getOrInsertFunction( + FunctionCallee Fn = M.getOrInsertFunction( Func, FunctionType::get(Type::getVoidTy(C), ArgTypes, false)); Instruction *RetAddr = CallInst::Create( diff --git a/lib/Transforms/Utils/EscapeEnumerator.cpp b/lib/Transforms/Utils/EscapeEnumerator.cpp index 762a374c135c..914babeb6829 100644 --- a/lib/Transforms/Utils/EscapeEnumerator.cpp +++ b/lib/Transforms/Utils/EscapeEnumerator.cpp @@ -1,9 +1,8 @@ //===- EscapeEnumerator.cpp -----------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. 
See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -19,7 +18,7 @@ #include "llvm/IR/Module.h" using namespace llvm; -static Constant *getDefaultPersonalityFn(Module *M) { +static FunctionCallee getDefaultPersonalityFn(Module *M) { LLVMContext &C = M->getContext(); Triple T(M->getTargetTriple()); EHPersonality Pers = getDefaultEHPersonality(T); @@ -69,8 +68,8 @@ IRBuilder<> *EscapeEnumerator::Next() { BasicBlock *CleanupBB = BasicBlock::Create(C, CleanupBBName, &F); Type *ExnTy = StructType::get(Type::getInt8PtrTy(C), Type::getInt32Ty(C)); if (!F.hasPersonalityFn()) { - Constant *PersFn = getDefaultPersonalityFn(F.getParent()); - F.setPersonalityFn(PersFn); + FunctionCallee PersFn = getDefaultPersonalityFn(F.getParent()); + F.setPersonalityFn(cast<Constant>(PersFn.getCallee())); } if (isScopedEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) { diff --git a/lib/Transforms/Utils/Evaluator.cpp b/lib/Transforms/Utils/Evaluator.cpp index e875cd686b00..0e203f4e075d 100644 --- a/lib/Transforms/Utils/Evaluator.cpp +++ b/lib/Transforms/Utils/Evaluator.cpp @@ -1,9 +1,8 @@ //===- Evaluator.cpp - LLVM IR evaluator ----------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -175,6 +174,34 @@ static bool isSimpleEnoughPointerToCommit(Constant *C) { return false; } +/// Apply 'Func' to Ptr. If this returns nullptr, introspect the pointer's +/// type and walk down through the initial elements to obtain additional +/// pointers to try. Returns the first non-null return value from Func, or +/// nullptr if the type can't be introspected further. +static Constant * +evaluateBitcastFromPtr(Constant *Ptr, const DataLayout &DL, + const TargetLibraryInfo *TLI, + std::function<Constant *(Constant *)> Func) { + Constant *Val; + while (!(Val = Func(Ptr))) { + // If Ty is a struct, we can convert the pointer to the struct + // into a pointer to its first member. + // FIXME: This could be extended to support arrays as well. + Type *Ty = cast<PointerType>(Ptr->getType())->getElementType(); + if (!isa<StructType>(Ty)) + break; + + IntegerType *IdxTy = IntegerType::get(Ty->getContext(), 32); + Constant *IdxZero = ConstantInt::get(IdxTy, 0, false); + Constant *const IdxList[] = {IdxZero, IdxZero}; + + Ptr = ConstantExpr::getGetElementPtr(Ty, Ptr, IdxList); + if (auto *FoldedPtr = ConstantFoldConstant(Ptr, DL, TLI)) + Ptr = FoldedPtr; + } + return Val; +} + static Constant *getInitializer(Constant *C) { auto *GV = dyn_cast<GlobalVariable>(C); return GV && GV->hasDefinitiveInitializer() ? GV->getInitializer() : nullptr; @@ -185,8 +212,14 @@ static Constant *getInitializer(Constant *C) { Constant *Evaluator::ComputeLoadResult(Constant *P) { // If this memory location has been recently stored, use the stored value: it // is the most up-to-date. 
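// --- [Editorial sketch, not part of the patch] ---------------------------
// The Constant* -> FunctionCallee updates above reflect that
// Module::getOrInsertFunction now returns a handle carrying the callee
// together with its FunctionType. Minimal illustration; the runtime function
// name and the void() signature are just an example.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
using namespace llvm;

static CallInst *emitRuntimeHook(Module &M, IRBuilder<> &Builder,
                                 StringRef Name) {
  FunctionCallee Fn =
      M.getOrInsertFunction(Name, Type::getVoidTy(M.getContext()));
  // CreateCall accepts the FunctionCallee directly; no cast back to a
  // Function* or Constant* is needed.
  return Builder.CreateCall(Fn);
}
// --- [End editorial sketch] -----------------------------------------------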
- DenseMap<Constant*, Constant*>::const_iterator I = MutatedMemory.find(P); - if (I != MutatedMemory.end()) return I->second; + auto findMemLoc = [this](Constant *Ptr) { + DenseMap<Constant *, Constant *>::const_iterator I = + MutatedMemory.find(Ptr); + return I != MutatedMemory.end() ? I->second : nullptr; + }; + + if (Constant *Val = findMemLoc(P)) + return Val; // Access it. if (GlobalVariable *GV = dyn_cast<GlobalVariable>(P)) { @@ -204,13 +237,17 @@ Constant *Evaluator::ComputeLoadResult(Constant *P) { break; // Handle a constantexpr bitcast. case Instruction::BitCast: - Constant *Val = getVal(CE->getOperand(0)); - auto MM = MutatedMemory.find(Val); - auto *I = (MM != MutatedMemory.end()) ? MM->second - : getInitializer(CE->getOperand(0)); - if (I) + // We're evaluating a load through a pointer that was bitcast to a + // different type. See if the "from" pointer has recently been stored. + // If it hasn't, we may still be able to find a stored pointer by + // introspecting the type. + Constant *Val = + evaluateBitcastFromPtr(CE->getOperand(0), DL, TLI, findMemLoc); + if (!Val) + Val = getInitializer(CE->getOperand(0)); + if (Val) return ConstantFoldLoadThroughBitcast( - I, P->getType()->getPointerElementType(), DL); + Val, P->getType()->getPointerElementType(), DL); break; } } @@ -330,37 +367,26 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, << "Attempting to resolve bitcast on constant ptr.\n"); // If we're evaluating a store through a bitcast, then we need // to pull the bitcast off the pointer type and push it onto the - // stored value. - Ptr = CE->getOperand(0); - - Type *NewTy = cast<PointerType>(Ptr->getType())->getElementType(); - - // In order to push the bitcast onto the stored value, a bitcast - // from NewTy to Val's type must be legal. If it's not, we can try - // introspecting NewTy to find a legal conversion. - Constant *NewVal; - while (!(NewVal = ConstantFoldLoadThroughBitcast(Val, NewTy, DL))) { - // If NewTy is a struct, we can convert the pointer to the struct - // into a pointer to its first member. - // FIXME: This could be extended to support arrays as well. - if (StructType *STy = dyn_cast<StructType>(NewTy)) { - NewTy = STy->getTypeAtIndex(0U); - - IntegerType *IdxTy = IntegerType::get(NewTy->getContext(), 32); - Constant *IdxZero = ConstantInt::get(IdxTy, 0, false); - Constant * const IdxList[] = {IdxZero, IdxZero}; - - Ptr = ConstantExpr::getGetElementPtr(nullptr, Ptr, IdxList); - if (auto *FoldedPtr = ConstantFoldConstant(Ptr, DL, TLI)) - Ptr = FoldedPtr; - - // If we can't improve the situation by introspecting NewTy, - // we have to give up. - } else { - LLVM_DEBUG(dbgs() << "Failed to bitcast constant ptr, can not " - "evaluate.\n"); - return false; + // stored value. In order to push the bitcast onto the stored value, + // a bitcast from the pointer's element type to Val's type must be + // legal. If it's not, we can try introspecting the type to find a + // legal conversion. 
+ + auto castValTy = [&](Constant *P) -> Constant * { + Type *Ty = cast<PointerType>(P->getType())->getElementType(); + if (Constant *FV = ConstantFoldLoadThroughBitcast(Val, Ty, DL)) { + Ptr = P; + return FV; } + return nullptr; + }; + + Constant *NewVal = + evaluateBitcastFromPtr(CE->getOperand(0), DL, TLI, castValTy); + if (!NewVal) { + LLVM_DEBUG(dbgs() << "Failed to bitcast constant ptr, can not " + "evaluate.\n"); + return false; } Val = NewVal; @@ -541,7 +567,8 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, if (Callee->isDeclaration()) { // If this is a function we can constant fold, do it. - if (Constant *C = ConstantFoldCall(CS, Callee, Formals, TLI)) { + if (Constant *C = ConstantFoldCall(cast<CallBase>(CS.getInstruction()), + Callee, Formals, TLI)) { InstResult = castCallResultIfNeeded(CS.getCalledValue(), C); if (!InstResult) return false; diff --git a/lib/Transforms/Utils/FlattenCFG.cpp b/lib/Transforms/Utils/FlattenCFG.cpp index d9778f4a1fb7..0c52e6f3703b 100644 --- a/lib/Transforms/Utils/FlattenCFG.cpp +++ b/lib/Transforms/Utils/FlattenCFG.cpp @@ -1,9 +1,8 @@ //===- FlatternCFG.cpp - Code to perform CFG flattening -------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/FunctionComparator.cpp b/lib/Transforms/Utils/FunctionComparator.cpp index a717d9b72819..a9b28754c8e9 100644 --- a/lib/Transforms/Utils/FunctionComparator.cpp +++ b/lib/Transforms/Utils/FunctionComparator.cpp @@ -1,9 +1,8 @@ //===- FunctionComparator.h - Function Comparator -------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -114,6 +113,19 @@ int FunctionComparator::cmpAttrs(const AttributeList L, for (; LI != LE && RI != RE; ++LI, ++RI) { Attribute LA = *LI; Attribute RA = *RI; + if (LA.isTypeAttribute() && RA.isTypeAttribute()) { + if (LA.getKindAsEnum() != RA.getKindAsEnum()) + return cmpNumbers(LA.getKindAsEnum(), RA.getKindAsEnum()); + + Type *TyL = LA.getValueAsType(); + Type *TyR = RA.getValueAsType(); + if (TyL && TyR) + return cmpTypes(TyL, TyR); + + // Two pointers, at least one null, so the comparison result is + // independent of the value of a real pointer. 
+ return cmpNumbers((uint64_t)TyL, (uint64_t)TyR); + } if (LA < RA) return -1; if (RA < LA) @@ -557,31 +569,20 @@ int FunctionComparator::cmpOperations(const Instruction *L, } if (const CmpInst *CI = dyn_cast<CmpInst>(L)) return cmpNumbers(CI->getPredicate(), cast<CmpInst>(R)->getPredicate()); - if (const CallInst *CI = dyn_cast<CallInst>(L)) { - if (int Res = cmpNumbers(CI->getCallingConv(), - cast<CallInst>(R)->getCallingConv())) + if (auto CSL = CallSite(const_cast<Instruction *>(L))) { + auto CSR = CallSite(const_cast<Instruction *>(R)); + if (int Res = cmpNumbers(CSL.getCallingConv(), CSR.getCallingConv())) return Res; - if (int Res = - cmpAttrs(CI->getAttributes(), cast<CallInst>(R)->getAttributes())) + if (int Res = cmpAttrs(CSL.getAttributes(), CSR.getAttributes())) return Res; - if (int Res = cmpOperandBundlesSchema(CI, R)) - return Res; - return cmpRangeMetadata( - CI->getMetadata(LLVMContext::MD_range), - cast<CallInst>(R)->getMetadata(LLVMContext::MD_range)); - } - if (const InvokeInst *II = dyn_cast<InvokeInst>(L)) { - if (int Res = cmpNumbers(II->getCallingConv(), - cast<InvokeInst>(R)->getCallingConv())) + if (int Res = cmpOperandBundlesSchema(L, R)) return Res; - if (int Res = - cmpAttrs(II->getAttributes(), cast<InvokeInst>(R)->getAttributes())) - return Res; - if (int Res = cmpOperandBundlesSchema(II, R)) - return Res; - return cmpRangeMetadata( - II->getMetadata(LLVMContext::MD_range), - cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range)); + if (const CallInst *CI = dyn_cast<CallInst>(L)) + if (int Res = cmpNumbers(CI->getTailCallKind(), + cast<CallInst>(R)->getTailCallKind())) + return Res; + return cmpRangeMetadata(L->getMetadata(LLVMContext::MD_range), + R->getMetadata(LLVMContext::MD_range)); } if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) { ArrayRef<unsigned> LIndices = IVI->getIndices(); diff --git a/lib/Transforms/Utils/FunctionImportUtils.cpp b/lib/Transforms/Utils/FunctionImportUtils.cpp index a9772e31da50..c9cc0990f237 100644 --- a/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -1,9 +1,8 @@ //===- lib/Transforms/Utils/FunctionImportUtils.cpp - Importing utilities -===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -130,7 +129,7 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV, // definitions upon import, so that they are available for inlining // and/or optimization, but are turned into declarations later // during the EliminateAvailableExternally pass. - if (doImportAsDefinition(SGV) && !dyn_cast<GlobalAlias>(SGV)) + if (doImportAsDefinition(SGV) && !isa<GlobalAlias>(SGV)) return GlobalValue::AvailableExternallyLinkage; // An imported external declaration stays external. return SGV->getLinkage(); @@ -159,7 +158,7 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV, // equivalent, so the issue described above for weak_any does not exist, // and the definition can be imported. It can be treated similarly // to an imported externally visible global value. 
- if (doImportAsDefinition(SGV) && !dyn_cast<GlobalAlias>(SGV)) + if (doImportAsDefinition(SGV) && !isa<GlobalAlias>(SGV)) return GlobalValue::AvailableExternallyLinkage; else return GlobalValue::ExternalLinkage; @@ -177,7 +176,7 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV, // If we are promoting the local to global scope, it is handled // similarly to a normal externally visible global. if (DoPromote) { - if (doImportAsDefinition(SGV) && !dyn_cast<GlobalAlias>(SGV)) + if (doImportAsDefinition(SGV) && !isa<GlobalAlias>(SGV)) return GlobalValue::AvailableExternallyLinkage; else return GlobalValue::ExternalLinkage; @@ -230,11 +229,11 @@ void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { } } - // Mark read-only variables which can be imported with specific attribute. - // We can't internalize them now because IRMover will fail to link variable - // definitions to their external declarations during ThinLTO import. We'll - // internalize read-only variables later, after import is finished. - // See internalizeImmutableGVs. + // Mark read/write-only variables which can be imported with specific + // attribute. We can't internalize them now because IRMover will fail + // to link variable definitions to their external declarations during + // ThinLTO import. We'll internalize read-only variables later, after + // import is finished. See internalizeGVsAfterImport. // // If global value dead stripping is not enabled in summary then // propagateConstants hasn't been run. We can't internalize GV @@ -242,13 +241,16 @@ void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { if (!GV.isDeclaration() && VI && ImportIndex.withGlobalValueDeadStripping()) { const auto &SL = VI.getSummaryList(); auto *GVS = SL.empty() ? nullptr : dyn_cast<GlobalVarSummary>(SL[0].get()); - if (GVS && GVS->isReadOnly()) + // At this stage "maybe" is "definitely" + if (GVS && (GVS->maybeReadOnly() || GVS->maybeWriteOnly())) cast<GlobalVariable>(&GV)->addAttribute("thinlto-internalize"); } bool DoPromote = false; if (GV.hasLocalLinkage() && ((DoPromote = shouldPromoteLocalToGlobal(&GV)) || isPerformingImport())) { + // Save the original name string before we rename GV below. + auto Name = GV.getName().str(); // Once we change the name or linkage it is difficult to determine // again whether we should promote since shouldPromoteLocalToGlobal needs // to locate the summary (based on GUID from name and linkage). Therefore, @@ -257,6 +259,12 @@ void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { GV.setLinkage(getLinkage(&GV, DoPromote)); if (!GV.hasLocalLinkage()) GV.setVisibility(GlobalValue::HiddenVisibility); + + // If we are renaming a COMDAT leader, ensure that we record the COMDAT + // for later renaming as well. This is required for COFF. + if (const auto *C = GV.getComdat()) + if (C->getName() == Name) + RenamedComdats.try_emplace(C, M.getOrInsertComdat(GV.getName())); } else GV.setLinkage(getLinkage(&GV, /* DoPromote */ false)); @@ -281,6 +289,16 @@ void FunctionImportGlobalProcessing::processGlobalsForThinLTO() { processGlobalForThinLTO(SF); for (GlobalAlias &GA : M.aliases()) processGlobalForThinLTO(GA); + + // Replace any COMDATS that required renaming (because the COMDAT leader was + // promoted and renamed). 
+ if (!RenamedComdats.empty()) + for (auto &GO : M.global_objects()) + if (auto *C = GO.getComdat()) { + auto Replacement = RenamedComdats.find(C); + if (Replacement != RenamedComdats.end()) + GO.setComdat(Replacement->second); + } } bool FunctionImportGlobalProcessing::run() { diff --git a/lib/Transforms/Utils/GlobalStatus.cpp b/lib/Transforms/Utils/GlobalStatus.cpp index ff6970db47da..a2942869130d 100644 --- a/lib/Transforms/Utils/GlobalStatus.cpp +++ b/lib/Transforms/Utils/GlobalStatus.cpp @@ -1,9 +1,8 @@ //===-- GlobalStatus.cpp - Compute status info for globals -----------------==// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/Utils/GuardUtils.cpp b/lib/Transforms/Utils/GuardUtils.cpp index 08de0a4c53e9..34c32d9c0c98 100644 --- a/lib/Transforms/Utils/GuardUtils.cpp +++ b/lib/Transforms/Utils/GuardUtils.cpp @@ -1,9 +1,8 @@ //===-- GuardUtils.cpp - Utils for work with guards -------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // Utils that are used to perform transformations related to guards and their diff --git a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp index 02482c550321..8041e66e6c4c 100644 --- a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp +++ b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp @@ -1,9 +1,8 @@ //===-- ImportedFunctionsInliningStats.cpp ----------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // Generating inliner statistics for imported functions, mostly useful for diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp index 623fe91a5a60..a7f0f7ac5d61 100644 --- a/lib/Transforms/Utils/InlineFunction.cpp +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -1,9 +1,8 @@ //===- InlineFunction.cpp - Code to perform function inlining -------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -85,16 +84,10 @@ PreserveAlignmentAssumptions("preserve-alignment-assumptions-during-inlining", cl::init(true), cl::Hidden, cl::desc("Convert align attributes to assumptions during inlining.")); -llvm::InlineResult llvm::InlineFunction(CallInst *CI, InlineFunctionInfo &IFI, - AAResults *CalleeAAR, - bool InsertLifetime) { - return InlineFunction(CallSite(CI), IFI, CalleeAAR, InsertLifetime); -} - -llvm::InlineResult llvm::InlineFunction(InvokeInst *II, InlineFunctionInfo &IFI, +llvm::InlineResult llvm::InlineFunction(CallBase *CB, InlineFunctionInfo &IFI, AAResults *CalleeAAR, bool InsertLifetime) { - return InlineFunction(CallSite(II), IFI, CalleeAAR, InsertLifetime); + return InlineFunction(CallSite(CB), IFI, CalleeAAR, InsertLifetime); } namespace { @@ -1042,11 +1035,10 @@ static void AddAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap, SmallSetVector<const Argument *, 4> NAPtrArgs; for (const Value *V : PtrArgs) { - SmallVector<Value *, 4> Objects; - GetUnderlyingObjects(const_cast<Value*>(V), - Objects, DL, /* LI = */ nullptr); + SmallVector<const Value *, 4> Objects; + GetUnderlyingObjects(V, Objects, DL, /* LI = */ nullptr); - for (Value *O : Objects) + for (const Value *O : Objects) ObjSet.insert(O); } @@ -1216,14 +1208,14 @@ static void UpdateCallGraphAfterInlining(CallSite CS, // If the call was inlined, but then constant folded, there is no edge to // add. Check for this case. - Instruction *NewCall = dyn_cast<Instruction>(VMI->second); + auto *NewCall = dyn_cast<CallBase>(VMI->second); if (!NewCall) continue; // We do not treat intrinsic calls like real function calls because we // expect them to become inline code; do not add an edge for an intrinsic. - CallSite CS = CallSite(NewCall); - if (CS && CS.getCalledFunction() && CS.getCalledFunction()->isIntrinsic()) + if (NewCall->getCalledFunction() && + NewCall->getCalledFunction()->isIntrinsic()) continue; // Remember that this call site got inlined for the client of @@ -1236,19 +1228,19 @@ static void UpdateCallGraphAfterInlining(CallSite CS, // destination. This can also happen if the call graph node of the caller // was just unnecessarily imprecise. if (!I->second->getFunction()) - if (Function *F = CallSite(NewCall).getCalledFunction()) { + if (Function *F = NewCall->getCalledFunction()) { // Indirect call site resolved to direct call. - CallerNode->addCalledFunction(CallSite(NewCall), CG[F]); + CallerNode->addCalledFunction(NewCall, CG[F]); continue; } - CallerNode->addCalledFunction(CallSite(NewCall), I->second); + CallerNode->addCalledFunction(NewCall, I->second); } // Update the call graph by deleting the edge from Callee to Caller. We must // do this after the loop above in case Caller and Callee are the same. - CallerNode->removeCallEdgeFor(CS); + CallerNode->removeCallEdgeFor(*cast<CallBase>(CS.getInstruction())); } static void HandleByValArgumentInit(Value *Dst, Value *Src, Module *M, @@ -1353,6 +1345,44 @@ static bool allocaWouldBeStaticInEntry(const AllocaInst *AI ) { return isa<Constant>(AI->getArraySize()) && !AI->isUsedWithInAlloca(); } +/// Returns a DebugLoc for a new DILocation which is a clone of \p OrigDL +/// inlined at \p InlinedAt. \p IANodes is an inlined-at cache. 
+static DebugLoc inlineDebugLoc(DebugLoc OrigDL, DILocation *InlinedAt, + LLVMContext &Ctx, + DenseMap<const MDNode *, MDNode *> &IANodes) { + auto IA = DebugLoc::appendInlinedAt(OrigDL, InlinedAt, Ctx, IANodes); + return DebugLoc::get(OrigDL.getLine(), OrigDL.getCol(), OrigDL.getScope(), + IA); +} + +/// Returns the LoopID for a loop which has been cloned from another +/// function for inlining with the new inlined-at start and end locs. +static MDNode *inlineLoopID(const MDNode *OrigLoopId, DILocation *InlinedAt, + LLVMContext &Ctx, + DenseMap<const MDNode *, MDNode *> &IANodes) { + assert(OrigLoopId && OrigLoopId->getNumOperands() > 0 && + "Loop ID needs at least one operand"); + assert(OrigLoopId && OrigLoopId->getOperand(0).get() == OrigLoopId && + "Loop ID should refer to itself"); + + // Save space for the self-referential LoopID. + SmallVector<Metadata *, 4> MDs = {nullptr}; + + for (unsigned i = 1; i < OrigLoopId->getNumOperands(); ++i) { + Metadata *MD = OrigLoopId->getOperand(i); + // Update the DILocations to encode the inlined-at metadata. + if (DILocation *DL = dyn_cast<DILocation>(MD)) + MDs.push_back(inlineDebugLoc(DL, InlinedAt, Ctx, IANodes)); + else + MDs.push_back(MD); + } + + MDNode *NewLoopID = MDNode::getDistinct(Ctx, MDs); + // Insert the self-referential LoopID. + NewLoopID->replaceOperandWith(0, NewLoopID); + return NewLoopID; +} + /// Update inlined instructions' line numbers to /// to encode location where these instructions are inlined. static void fixupLineNumbers(Function *Fn, Function::iterator FI, @@ -1378,10 +1408,17 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, for (; FI != Fn->end(); ++FI) { for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ++BI) { + // Loop metadata needs to be updated so that the start and end locs + // reference inlined-at locations. + if (MDNode *LoopID = BI->getMetadata(LLVMContext::MD_loop)) { + MDNode *NewLoopID = + inlineLoopID(LoopID, InlinedAtNode, BI->getContext(), IANodes); + BI->setMetadata(LLVMContext::MD_loop, NewLoopID); + } + if (DebugLoc DL = BI->getDebugLoc()) { - auto IA = DebugLoc::appendInlinedAt(DL, InlinedAtNode, BI->getContext(), - IANodes); - auto IDL = DebugLoc::get(DL.getLine(), DL.getCol(), DL.getScope(), IA); + DebugLoc IDL = + inlineDebugLoc(DL, InlinedAtNode, BI->getContext(), IANodes); BI->setDebugLoc(IDL); continue; } @@ -1448,47 +1485,45 @@ static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap, CalleeEntryCount.getCount() < 1) return; auto CallSiteCount = PSI ? PSI->getProfileCount(TheCall, CallerBFI) : None; - uint64_t CallCount = + int64_t CallCount = std::min(CallSiteCount.hasValue() ? CallSiteCount.getValue() : 0, CalleeEntryCount.getCount()); - - for (auto const &Entry : VMap) - if (isa<CallInst>(Entry.first)) - if (auto *CI = dyn_cast_or_null<CallInst>(Entry.second)) - CI->updateProfWeight(CallCount, CalleeEntryCount.getCount()); - for (BasicBlock &BB : *Callee) - // No need to update the callsite if it is pruned during inlining. - if (VMap.count(&BB)) - for (Instruction &I : BB) - if (CallInst *CI = dyn_cast<CallInst>(&I)) - CI->updateProfWeight(CalleeEntryCount.getCount() - CallCount, - CalleeEntryCount.getCount()); + updateProfileCallee(Callee, -CallCount, &VMap); } -/// Update the entry count of callee after inlining. -/// -/// The callsite's block count is subtracted from the callee's function entry -/// count.
-static void updateCalleeCount(BlockFrequencyInfo *CallerBFI, BasicBlock *CallBB, - Instruction *CallInst, Function *Callee, - ProfileSummaryInfo *PSI) { - // If the callee has a original count of N, and the estimated count of - // callsite is M, the new callee count is set to N - M. M is estimated from - // the caller's entry count, its entry block frequency and the block frequency - // of the callsite. +void llvm::updateProfileCallee( + Function *Callee, int64_t entryDelta, + const ValueMap<const Value *, WeakTrackingVH> *VMap) { auto CalleeCount = Callee->getEntryCount(); - if (!CalleeCount.hasValue() || !PSI) - return; - auto CallCount = PSI->getProfileCount(CallInst, CallerBFI); - if (!CallCount.hasValue()) + if (!CalleeCount.hasValue()) return; + + uint64_t priorEntryCount = CalleeCount.getCount(); + uint64_t newEntryCount; + + // Since CallSiteCount is an estimate, it could exceed the original callee - // count and has to be set to 0. - if (CallCount.getValue() > CalleeCount.getCount()) - CalleeCount.setCount(0); + // count and has to be set to 0, so guard against underflow. + if (entryDelta < 0 && static_cast<uint64_t>(-entryDelta) > priorEntryCount) + newEntryCount = 0; else - CalleeCount.setCount(CalleeCount.getCount() - CallCount.getValue()); - Callee->setEntryCount(CalleeCount); + newEntryCount = priorEntryCount + entryDelta; + + Callee->setEntryCount(newEntryCount); + + // During inlining? + if (VMap) { + uint64_t cloneEntryCount = priorEntryCount - newEntryCount; + for (auto const &Entry : *VMap) + if (isa<CallInst>(Entry.first)) + if (auto *CI = dyn_cast_or_null<CallInst>(Entry.second)) + CI->updateProfWeight(cloneEntryCount, priorEntryCount); + } + for (BasicBlock &BB : *Callee) + // No need to update the callsite if it is pruned during inlining. + if (!VMap || VMap->count(&BB)) + for (Instruction &I : BB) + if (CallInst *CI = dyn_cast<CallInst>(&I)) + CI->updateProfWeight(newEntryCount, priorEntryCount); } /// This function inlines the called function into the basic block of the @@ -1507,6 +1542,10 @@ llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, assert(TheCall->getParent() && TheCall->getFunction() && "Instruction not in function!"); + // FIXME: we don't inline callbr yet. + if (isa<CallBrInst>(TheCall)) + return false; + // If IFI has any state in it, zap it before we fill it in. IFI.reset(); @@ -1684,8 +1723,6 @@ llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, updateCallProfile(CalledFunc, VMap, CalledFunc->getEntryCount(), TheCall, IFI.PSI, IFI.CallerBFI); - // Update the profile count of callee. - updateCalleeCount(IFI.CallerBFI, OrigBB, TheCall, CalledFunc, IFI.PSI); // Inject byval arguments initialization. for (std::pair<Value*, Value*> &Init : ByValInit) @@ -1734,6 +1771,8 @@ llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, Instruction *NewI = nullptr; if (isa<CallInst>(I)) NewI = CallInst::Create(cast<CallInst>(I), OpDefs, I); + else if (isa<CallBrInst>(I)) + NewI = CallBrInst::Create(cast<CallBrInst>(I), OpDefs, I); else NewI = InvokeInst::Create(cast<InvokeInst>(I), OpDefs, I); @@ -1817,8 +1856,7 @@ llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Move any dbg.declares describing the allocas into the entry basic block.
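// --- [Editorial sketch, not part of the patch] ---------------------------
// updateProfileCallee, introduced above, centralizes the entry-count math: a
// negative delta is clamped at zero, and when a VMap from cloning is given
// the cloned call sites' profile weights are rescaled too. Hedged usage
// sketch; it assumes the declaration is exposed via Cloning.h and that the
// call-site count was estimated elsewhere.
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Utils/Cloning.h"
using namespace llvm;

static void discountExtractedCalls(Function &Callee, int64_t CallSiteCount) {
  // No VMap: only the callee's entry count is adjusted.
  updateProfileCallee(&Callee, -CallSiteCount, /*VMap=*/nullptr);
}
// --- [End editorial sketch] -----------------------------------------------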
DIBuilder DIB(*Caller->getParent()); for (auto &AI : IFI.StaticAllocas) - replaceDbgDeclareForAlloca(AI, AI, DIB, DIExpression::NoDeref, 0, - DIExpression::NoDeref); + replaceDbgDeclareForAlloca(AI, AI, DIB, DIExpression::ApplyOffset, 0); } SmallVector<Value*,4> VarArgsToForward; @@ -1869,10 +1907,8 @@ llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Add VarArgs to existing parameters. SmallVector<Value *, 6> Params(CI->arg_operands()); Params.append(VarArgsToForward.begin(), VarArgsToForward.end()); - CallInst *NewCI = - CallInst::Create(CI->getCalledFunction() ? CI->getCalledFunction() - : CI->getCalledValue(), - Params, "", CI); + CallInst *NewCI = CallInst::Create( + CI->getFunctionType(), CI->getCalledOperand(), Params, "", CI); NewCI->setDebugLoc(CI->getDebugLoc()); NewCI->setAttributes(Attrs); NewCI->setCallingConv(CI->getCallingConv()); @@ -2038,6 +2074,8 @@ llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, Instruction *NewInst; if (CS.isCall()) NewInst = CallInst::Create(cast<CallInst>(I), OpBundles, I); + else if (CS.isCallBr()) + NewInst = CallBrInst::Create(cast<CallBrInst>(I), OpBundles, I); else NewInst = InvokeInst::Create(cast<InvokeInst>(I), OpBundles, I); NewInst->takeName(I); diff --git a/lib/Transforms/Utils/InstructionNamer.cpp b/lib/Transforms/Utils/InstructionNamer.cpp index 003721f2b939..6c4fc1ceb991 100644 --- a/lib/Transforms/Utils/InstructionNamer.cpp +++ b/lib/Transforms/Utils/InstructionNamer.cpp @@ -1,9 +1,8 @@ //===- InstructionNamer.cpp - Give anonymous instructions names -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/IntegerDivision.cpp b/lib/Transforms/Utils/IntegerDivision.cpp index 4a359b99bebd..9082049c82da 100644 --- a/lib/Transforms/Utils/IntegerDivision.cpp +++ b/lib/Transforms/Utils/IntegerDivision.cpp @@ -1,9 +1,8 @@ //===-- IntegerDivision.cpp - Expand integer division ---------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/LCSSA.cpp b/lib/Transforms/Utils/LCSSA.cpp index 53d444b309d5..29e7c5260f46 100644 --- a/lib/Transforms/Utils/LCSSA.cpp +++ b/lib/Transforms/Utils/LCSSA.cpp @@ -1,9 +1,8 @@ //===-- LCSSA.cpp - Convert loops into loop-closed SSA form ---------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -32,11 +31,12 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" @@ -45,6 +45,7 @@ #include "llvm/IR/PredIteratorCache.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" using namespace llvm; @@ -198,6 +199,17 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, continue; } + // If we added a single PHI, it must dominate all uses and we can directly + // rename it. + if (AddedPHIs.size() == 1) { + // Tell the VHs that the uses changed. This updates SCEV's caches. + // We might call ValueIsRAUWd multiple times for the same value. + if (UseToRewrite->get()->hasValueHandle()) + ValueHandleBase::ValueIsRAUWd(*UseToRewrite, AddedPHIs[0]); + UseToRewrite->set(AddedPHIs[0]); + continue; + } + // Otherwise, do full PHI insertion. SSAUpdate.RewriteUse(*UseToRewrite); } @@ -211,9 +223,12 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, BasicBlock *UserBB = DVI->getParent(); if (InstBB == UserBB || L->contains(UserBB)) continue; - // We currently only handle debug values residing in blocks where we have - // inserted a PHI instruction. - if (Value *V = SSAUpdate.FindValueForBlock(UserBB)) + // We currently only handle debug values residing in blocks that were + // traversed while rewriting the uses. If we inserted just a single PHI, + // we will handle all relevant debug values. + Value *V = AddedPHIs.size() == 1 ? AddedPHIs[0] + : SSAUpdate.FindValueForBlock(UserBB); + if (V) DVI->setOperand(0, MetadataAsValue::get(Ctx, ValueAsMetadata::get(V))); } @@ -306,6 +321,12 @@ bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution *SE) { bool Changed = false; +#ifdef EXPENSIVE_CHECKS + // Verify all sub-loops are in LCSSA form already. + for (Loop *SubLoop: L) + assert(SubLoop->isRecursivelyLCSSAForm(DT, *LI) && "Subloop not in LCSSA!"); +#endif + SmallVector<BasicBlock *, 8> ExitBlocks; L.getExitBlocks(ExitBlocks); if (ExitBlocks.empty()) @@ -325,6 +346,10 @@ bool llvm::formLCSSA(Loop &L, DominatorTree &DT, LoopInfo *LI, // Look at all the instructions in the loop, checking to see if they have uses // outside the loop. If so, put them into the worklist to rewrite those uses. for (BasicBlock *BB : BlocksDominatingExits) { + // Skip blocks that are part of any sub-loops, they must be in LCSSA + // already. + if (LI->getLoopFor(BB) != &L) + continue; for (Instruction &I : *BB) { // Reject two common cases fast: instructions with no uses (like stores) // and instructions with one use that is in the same block as this. 
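// --- [Editorial sketch, not part of the patch] ---------------------------
// For reference alongside the formLCSSA changes above: callers hand over the
// loop and whatever analyses they have, and the utility inserts exit-block
// PHIs for every value that is live outside the loop. Hedged sketch; it
// assumes the declaration from LoopUtils.h and that ScalarEvolution may be
// null.
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
using namespace llvm;

static bool putLoopIntoLCSSA(Loop &L, DominatorTree &DT, LoopInfo &LI,
                             ScalarEvolution *SE) {
  // Returns true if any PHI nodes were inserted.
  return formLCSSA(L, DT, &LI, SE);
}
// --- [End editorial sketch] -----------------------------------------------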
@@ -419,6 +444,8 @@ struct LCSSAWrapperPass : public FunctionPass { AU.addPreserved<GlobalsAAWrapperPass>(); AU.addPreserved<ScalarEvolutionWrapperPass>(); AU.addPreserved<SCEVAAWrapperPass>(); + AU.addPreserved<BranchProbabilityInfoWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); // This is needed to perform LCSSA verification inside LPPassManager AU.addRequired<LCSSAVerificationPass>(); @@ -462,5 +489,9 @@ PreservedAnalyses LCSSAPass::run(Function &F, FunctionAnalysisManager &AM) { PA.preserve<GlobalsAA>(); PA.preserve<SCEVAA>(); PA.preserve<ScalarEvolutionAnalysis>(); + // BPI maps terminators to probabilities, since we don't modify the CFG, no + // updates are needed to preserve it. + PA.preserve<BranchProbabilityAnalysis>(); + PA.preserve<MemorySSAAnalysis>(); return PA; } diff --git a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index e1592c867636..8c67d1dc6eb3 100644 --- a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -1,9 +1,8 @@ //===-- LibCallsShrinkWrap.cpp ----------------------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index 499e611acb57..39b6b889f91c 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -1,9 +1,8 @@ //===- Local.cpp - Functions to perform local transformations -------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -27,6 +26,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" @@ -49,7 +49,6 @@ #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" @@ -92,6 +91,10 @@ using namespace llvm::PatternMatch; STATISTIC(NumRemoved, "Number of unreachable basic blocks removed"); +// Max recursion depth for collectBitParts used when detecting bswap and +// bitreverse idioms +static const unsigned BitPartRecursionMaxDepth = 64; + //===----------------------------------------------------------------------===// // Local constant propagation. 
// @@ -129,7 +132,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, Builder.CreateBr(Destination); BI->eraseFromParent(); if (DTU) - DTU->deleteEdgeRelaxed(BB, OldDest); + DTU->applyUpdatesPermissive({{DominatorTree::Delete, BB, OldDest}}); return true; } @@ -205,7 +208,8 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, i = SI->removeCase(i); e = SI->case_end(); if (DTU) - DTU->deleteEdgeRelaxed(ParentBB, DefaultDest); + DTU->applyUpdatesPermissive( + {{DominatorTree::Delete, ParentBB, DefaultDest}}); continue; } @@ -253,7 +257,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, if (DeleteDeadConditions) RecursivelyDeleteTriviallyDeadInstructions(Cond, TLI); if (DTU) - DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->applyUpdatesPermissive(Updates); return true; } @@ -331,7 +335,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, } if (DTU) - DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->applyUpdatesPermissive(Updates); return true; } } @@ -416,8 +420,8 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, if (Constant *C = dyn_cast<Constant>(CI->getArgOperand(0))) return C->isNullValue() || isa<UndefValue>(C); - if (CallSite CS = CallSite(I)) - if (isMathLibCallNoop(CS, TLI)) + if (auto *Call = dyn_cast<CallBase>(I)) + if (isMathLibCallNoop(Call, TLI)) return true; return false; @@ -430,7 +434,7 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, bool llvm::RecursivelyDeleteTriviallyDeadInstructions( Value *V, const TargetLibraryInfo *TLI, MemorySSAUpdater *MSSAU) { Instruction *I = dyn_cast<Instruction>(V); - if (!I || !I->use_empty() || !isInstructionTriviallyDead(I, TLI)) + if (!I || !isInstructionTriviallyDead(I, TLI)) return false; SmallVector<Instruction*, 16> DeadInsts; @@ -665,7 +669,7 @@ void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred, if (PhiIt != OldPhiIt) PhiIt = &BB->front(); } if (DTU) - DTU->deleteEdgeRelaxed(Pred, BB); + DTU->applyUpdatesPermissive({{DominatorTree::Delete, Pred, BB}}); } /// MergeBasicBlockIntoOnlyPred - DestBB is a block with one predecessor and its @@ -734,7 +738,7 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, isa<UnreachableInst>(PredBB->getTerminator()) && "The successor list of PredBB isn't empty before " "applying corresponding DTU updates."); - DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->applyUpdatesPermissive(Updates); DTU->deleteBB(PredBB); // Recalculation of DomTree is needed when updating a forward DomTree and // the Entry BB is replaced. @@ -997,6 +1001,18 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, } } + // We cannot fold the block if it's a branch to an already present callbr + // successor because that creates duplicate successors. 
+ for (auto I = pred_begin(BB), E = pred_end(BB); I != E; ++I) { + if (auto *CBI = dyn_cast<CallBrInst>((*I)->getTerminator())) { + if (Succ == CBI->getDefaultDest()) + return false; + for (unsigned i = 0, e = CBI->getNumIndirectDests(); i != e; ++i) + if (Succ == CBI->getIndirectDest(i)) + return false; + } + } + LLVM_DEBUG(dbgs() << "Killing Trivial BB: \n" << *BB); SmallVector<DominatorTree::UpdateType, 32> Updates; @@ -1064,7 +1080,7 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, "applying corresponding DTU updates."); if (DTU) { - DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->applyUpdatesPermissive(Updates); DTU->deleteBB(BB); } else { BB->eraseFromParent(); // Delete the old basic block. @@ -1272,6 +1288,19 @@ static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { return false; } +/// Produce a DebugLoc to use for each dbg.declare/inst pair that are promoted +/// to a dbg.value. Because no machine insts can come from debug intrinsics, +/// only the scope and inlinedAt is significant. Zero line numbers are used in +/// case this DebugLoc leaks into any adjacent instructions. +static DebugLoc getDebugValueLoc(DbgVariableIntrinsic *DII, Instruction *Src) { + // Original dbg.declare must have a location. + DebugLoc DeclareLoc = DII->getDebugLoc(); + MDNode *Scope = DeclareLoc.getScope(); + DILocation *InlinedAt = DeclareLoc.getInlinedAt(); + // Produce an unknown location with the correct scope / inlinedAt fields. + return DebugLoc::get(0, 0, Scope, InlinedAt); +} + /// Inserts a llvm.dbg.value intrinsic before a store to an alloca'd value /// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic. void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, @@ -1280,9 +1309,11 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, auto *DIVar = DII->getVariable(); assert(DIVar && "Missing variable"); auto *DIExpr = DII->getExpression(); - Value *DV = SI->getOperand(0); + Value *DV = SI->getValueOperand(); + + DebugLoc NewLoc = getDebugValueLoc(DII, SI); - if (!valueCoversEntireFragment(SI->getValueOperand()->getType(), DII)) { + if (!valueCoversEntireFragment(DV->getType(), DII)) { // FIXME: If storing to a part of the variable described by the dbg.declare, // then we want to insert a dbg.value for the corresponding fragment. LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: " @@ -1292,14 +1323,12 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, // know nothing about the variable's content. DV = UndefValue::get(DV->getType()); if (!LdStHasDebugValue(DIVar, DIExpr, SI)) - Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, DII->getDebugLoc(), - SI); + Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI); return; } if (!LdStHasDebugValue(DIVar, DIExpr, SI)) - Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, DII->getDebugLoc(), - SI); + Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI); } /// Inserts a llvm.dbg.value intrinsic before a load of an alloca'd value @@ -1322,12 +1351,14 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, return; } + DebugLoc NewLoc = getDebugValueLoc(DII, nullptr); + // We are now tracking the loaded value instead of the address. In the // future if multi-location support is added to the IR, it might be // preferable to keep tracking both the loaded value and the original // address in case the alloca can not be elided. 
Instruction *DbgValue = Builder.insertDbgValueIntrinsic( - LI, DIVar, DIExpr, DII->getDebugLoc(), (Instruction *)nullptr); + LI, DIVar, DIExpr, NewLoc, (Instruction *)nullptr); DbgValue->insertAfter(LI); } @@ -1354,12 +1385,13 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, BasicBlock *BB = APN->getParent(); auto InsertionPt = BB->getFirstInsertionPt(); + DebugLoc NewLoc = getDebugValueLoc(DII, nullptr); + // The block may be a catchswitch block, which does not have a valid // insertion point. // FIXME: Insert dbg.value markers in the successors when appropriate. if (InsertionPt != BB->end()) - Builder.insertDbgValueIntrinsic(APN, DIVar, DIExpr, DII->getDebugLoc(), - &*InsertionPt); + Builder.insertDbgValueIntrinsic(APN, DIVar, DIExpr, NewLoc, &*InsertionPt); } /// Determine whether this alloca is either a VLA or an array. @@ -1414,10 +1446,11 @@ bool llvm::LowerDbgDeclare(Function &F) { // This is a call by-value or some other instruction that takes a // pointer to the variable. Insert a *value* intrinsic that describes // the variable by dereferencing the alloca. + DebugLoc NewLoc = getDebugValueLoc(DDI, nullptr); auto *DerefExpr = DIExpression::append(DDI->getExpression(), dwarf::DW_OP_deref); - DIB.insertDbgValueIntrinsic(AI, DDI->getVariable(), DerefExpr, - DDI->getDebugLoc(), CI); + DIB.insertDbgValueIntrinsic(AI, DDI->getVariable(), DerefExpr, NewLoc, + CI); } } DDI->eraseFromParent(); @@ -1519,14 +1552,14 @@ void llvm::findDbgUsers(SmallVectorImpl<DbgVariableIntrinsic *> &DbgUsers, bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, Instruction *InsertBefore, DIBuilder &Builder, - bool DerefBefore, int Offset, bool DerefAfter) { + uint8_t DIExprFlags, int Offset) { auto DbgAddrs = FindDbgAddrUses(Address); for (DbgVariableIntrinsic *DII : DbgAddrs) { DebugLoc Loc = DII->getDebugLoc(); auto *DIVar = DII->getVariable(); auto *DIExpr = DII->getExpression(); assert(DIVar && "Missing variable"); - DIExpr = DIExpression::prepend(DIExpr, DerefBefore, Offset, DerefAfter); + DIExpr = DIExpression::prepend(DIExpr, DIExprFlags, Offset); // Insert llvm.dbg.declare immediately before InsertBefore, and remove old // llvm.dbg.declare. Builder.insertDeclare(NewAddress, DIVar, DIExpr, Loc, InsertBefore); @@ -1538,10 +1571,10 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, } bool llvm::replaceDbgDeclareForAlloca(AllocaInst *AI, Value *NewAllocaAddress, - DIBuilder &Builder, bool DerefBefore, - int Offset, bool DerefAfter) { + DIBuilder &Builder, uint8_t DIExprFlags, + int Offset) { return replaceDbgDeclare(AI, NewAllocaAddress, AI->getNextNode(), Builder, - DerefBefore, Offset, DerefAfter); + DIExprFlags, Offset); } static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress, @@ -1594,120 +1627,119 @@ bool llvm::salvageDebugInfo(Instruction &I) { if (DbgUsers.empty()) return false; - auto &M = *I.getModule(); - auto &DL = M.getDataLayout(); + return salvageDebugInfoForDbgValues(I, DbgUsers); +} + +bool llvm::salvageDebugInfoForDbgValues( + Instruction &I, ArrayRef<DbgVariableIntrinsic *> DbgUsers) { auto &Ctx = I.getContext(); auto wrapMD = [&](Value *V) { return wrapValueInMetadata(Ctx, V); }; - auto doSalvage = [&](DbgVariableIntrinsic *DII, SmallVectorImpl<uint64_t> &Ops) { - auto *DIExpr = DII->getExpression(); - if (!Ops.empty()) { - // Do not add DW_OP_stack_value for DbgDeclare and DbgAddr, because they - // are implicitly pointing out the value as a DWARF memory location - // description. 
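// A usage sketch (not part of this patch) of the re-typed replaceDbgDeclare
// API shown above: the (DerefBefore, Offset, DerefAfter) booleans collapse
// into one DIExpression prepend-flags byte plus an offset. The flag name
// below is assumed to match the DIExpression prepend flags this change pairs
// with; the helper name is illustrative.
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

static void retargetDeclare(AllocaInst *OldAlloca, Value *NewAddress,
                            DIBuilder &DIB) {
  // Plain retargeting: no extra dereference, no offset adjustment.
  replaceDbgDeclareForAlloca(OldAlloca, NewAddress, DIB,
                             DIExpression::ApplyOffset, /*Offset=*/0);
}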
- bool WithStackValue = isa<DbgValueInst>(DII); - DIExpr = DIExpression::prependOpcodes(DIExpr, Ops, WithStackValue); - } + for (auto *DII : DbgUsers) { + // Do not add DW_OP_stack_value for DbgDeclare and DbgAddr, because they + // are implicitly pointing out the value as a DWARF memory location + // description. + bool StackValue = isa<DbgValueInst>(DII); + + DIExpression *DIExpr = + salvageDebugInfoImpl(I, DII->getExpression(), StackValue); + + // salvageDebugInfoImpl should fail on examining the first element of + // DbgUsers, or none of them. + if (!DIExpr) + return false; + DII->setOperand(0, wrapMD(I.getOperand(0))); DII->setOperand(2, MetadataAsValue::get(Ctx, DIExpr)); LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); + } + + return true; +} + +DIExpression *llvm::salvageDebugInfoImpl(Instruction &I, + DIExpression *SrcDIExpr, + bool WithStackValue) { + auto &M = *I.getModule(); + auto &DL = M.getDataLayout(); + + // Apply a vector of opcodes to the source DIExpression. + auto doSalvage = [&](SmallVectorImpl<uint64_t> &Ops) -> DIExpression * { + DIExpression *DIExpr = SrcDIExpr; + if (!Ops.empty()) { + DIExpr = DIExpression::prependOpcodes(DIExpr, Ops, WithStackValue); + } + return DIExpr; }; - auto applyOffset = [&](DbgVariableIntrinsic *DII, uint64_t Offset) { + // Apply the given offset to the source DIExpression. + auto applyOffset = [&](uint64_t Offset) -> DIExpression * { SmallVector<uint64_t, 8> Ops; DIExpression::appendOffset(Ops, Offset); - doSalvage(DII, Ops); + return doSalvage(Ops); }; - auto applyOps = [&](DbgVariableIntrinsic *DII, - std::initializer_list<uint64_t> Opcodes) { + // initializer-list helper for applying operators to the source DIExpression. + auto applyOps = + [&](std::initializer_list<uint64_t> Opcodes) -> DIExpression * { SmallVector<uint64_t, 8> Ops(Opcodes); - doSalvage(DII, Ops); + return doSalvage(Ops); }; if (auto *CI = dyn_cast<CastInst>(&I)) { - if (!CI->isNoopCast(DL)) - return false; - - // No-op casts are irrelevant for debug info. - MetadataAsValue *CastSrc = wrapMD(I.getOperand(0)); - for (auto *DII : DbgUsers) { - DII->setOperand(0, CastSrc); - LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); - } - return true; + // No-op casts and zexts are irrelevant for debug info. + if (CI->isNoopCast(DL) || isa<ZExtInst>(&I)) + return SrcDIExpr; + return nullptr; } else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { unsigned BitWidth = M.getDataLayout().getIndexSizeInBits(GEP->getPointerAddressSpace()); - // Rewrite a constant GEP into a DIExpression. Since we are performing - // arithmetic to compute the variable's *value* in the DIExpression, we - // need to mark the expression with a DW_OP_stack_value. + // Rewrite a constant GEP into a DIExpression. APInt Offset(BitWidth, 0); - if (GEP->accumulateConstantOffset(M.getDataLayout(), Offset)) - for (auto *DII : DbgUsers) - applyOffset(DII, Offset.getSExtValue()); - return true; + if (GEP->accumulateConstantOffset(M.getDataLayout(), Offset)) { + return applyOffset(Offset.getSExtValue()); + } else { + return nullptr; + } } else if (auto *BI = dyn_cast<BinaryOperator>(&I)) { // Rewrite binary operations with constant integer operands. 
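// A sketch (not part of this patch) of how the split salvage API introduced
// here composes: salvageDebugInfoImpl() rewrites a single DIExpression (or
// returns nullptr when the instruction is not salvageable), and the
// salvageDebugInfoForDbgValues() wrapper above applies that result to every
// debug user. The helper name is illustrative.
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

static DIExpression *salvageThroughGEP(GetElementPtrInst &GEP,
                                       DIExpression *Expr) {
  // A constant-offset GEP folds its offset into Expr; a variable index makes
  // this return nullptr and the dbg.value is left untouched.
  return salvageDebugInfoImpl(GEP, Expr, /*WithStackValue=*/true);
}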
auto *ConstInt = dyn_cast<ConstantInt>(I.getOperand(1)); if (!ConstInt || ConstInt->getBitWidth() > 64) - return false; + return nullptr; uint64_t Val = ConstInt->getSExtValue(); - for (auto *DII : DbgUsers) { - switch (BI->getOpcode()) { - case Instruction::Add: - applyOffset(DII, Val); - break; - case Instruction::Sub: - applyOffset(DII, -int64_t(Val)); - break; - case Instruction::Mul: - applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_mul}); - break; - case Instruction::SDiv: - applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_div}); - break; - case Instruction::SRem: - applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_mod}); - break; - case Instruction::Or: - applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_or}); - break; - case Instruction::And: - applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_and}); - break; - case Instruction::Xor: - applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_xor}); - break; - case Instruction::Shl: - applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_shl}); - break; - case Instruction::LShr: - applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_shr}); - break; - case Instruction::AShr: - applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_shra}); - break; - default: - // TODO: Salvage constants from each kind of binop we know about. - return false; - } + switch (BI->getOpcode()) { + case Instruction::Add: + return applyOffset(Val); + case Instruction::Sub: + return applyOffset(-int64_t(Val)); + case Instruction::Mul: + return applyOps({dwarf::DW_OP_constu, Val, dwarf::DW_OP_mul}); + case Instruction::SDiv: + return applyOps({dwarf::DW_OP_constu, Val, dwarf::DW_OP_div}); + case Instruction::SRem: + return applyOps({dwarf::DW_OP_constu, Val, dwarf::DW_OP_mod}); + case Instruction::Or: + return applyOps({dwarf::DW_OP_constu, Val, dwarf::DW_OP_or}); + case Instruction::And: + return applyOps({dwarf::DW_OP_constu, Val, dwarf::DW_OP_and}); + case Instruction::Xor: + return applyOps({dwarf::DW_OP_constu, Val, dwarf::DW_OP_xor}); + case Instruction::Shl: + return applyOps({dwarf::DW_OP_constu, Val, dwarf::DW_OP_shl}); + case Instruction::LShr: + return applyOps({dwarf::DW_OP_constu, Val, dwarf::DW_OP_shr}); + case Instruction::AShr: + return applyOps({dwarf::DW_OP_constu, Val, dwarf::DW_OP_shra}); + default: + // TODO: Salvage constants from each kind of binop we know about. + return nullptr; } - return true; - } else if (isa<LoadInst>(&I)) { - MetadataAsValue *AddrMD = wrapMD(I.getOperand(0)); - for (auto *DII : DbgUsers) { - // Rewrite the load into DW_OP_deref. - auto *DIExpr = DII->getExpression(); - DIExpr = DIExpression::prepend(DIExpr, DIExpression::WithDeref); - DII->setOperand(0, AddrMD); - DII->setOperand(2, MetadataAsValue::get(Ctx, DIExpr)); - LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); - } - return true; + // *Not* to do: we should not attempt to salvage load instructions, + // because the validity and lifetime of a dbg.value containing + // DW_OP_deref becomes difficult to analyze. See PR40628 for examples. } - return false; + return nullptr; } /// A replacement for a dbg.value expression. @@ -1849,21 +1881,10 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, return None; bool Signed = *Signedness == DIBasicType::Signedness::Signed; - - if (!Signed) { - // In the unsigned case, assume that a debugger will initialize the - // high bits to 0 and do a no-op conversion. 
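// A worked example (not part of this patch) of the Mul case above: given
// %y = mul i32 %x, 4, a dbg.value describing %y is rewritten to describe %x
// with {DW_OP_constu, 4, DW_OP_mul} prepended (plus DW_OP_stack_value for
// dbg.value users, per the WithStackValue flag). OldExpr stands in for the
// intrinsic's current expression.
#include "llvm/ADT/SmallVector.h"
#include "llvm/BinaryFormat/Dwarf.h"
#include "llvm/IR/DebugInfoMetadata.h"
using namespace llvm;

static DIExpression *undoMulBy4(DIExpression *OldExpr) {
  SmallVector<uint64_t, 4> Ops = {dwarf::DW_OP_constu, 4, dwarf::DW_OP_mul};
  return DIExpression::prependOpcodes(OldExpr, Ops, /*StackValue=*/true);
}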
- return Identity(DII); - } else { - // In the signed case, the high bits are given by sign extension, i.e: - // (To >> (ToBits - 1)) * ((2 ^ FromBits) - 1) - // Calculate the high bits and OR them together with the low bits. - SmallVector<uint64_t, 8> Ops({dwarf::DW_OP_dup, dwarf::DW_OP_constu, - (ToBits - 1), dwarf::DW_OP_shr, - dwarf::DW_OP_lit0, dwarf::DW_OP_not, - dwarf::DW_OP_mul, dwarf::DW_OP_or}); - return DIExpression::appendToStack(DII.getExpression(), Ops); - } + dwarf::TypeKind TK = Signed ? dwarf::DW_ATE_signed : dwarf::DW_ATE_unsigned; + SmallVector<uint64_t, 8> Ops({dwarf::DW_OP_LLVM_convert, ToBits, TK, + dwarf::DW_OP_LLVM_convert, FromBits, TK}); + return DIExpression::appendToStack(DII.getExpression(), Ops); }; return rewriteDebugUsers(From, To, DomPoint, DT, SignOrZeroExt); } @@ -1894,10 +1915,14 @@ unsigned llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { } unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap, - bool PreserveLCSSA, DomTreeUpdater *DTU) { + bool PreserveLCSSA, DomTreeUpdater *DTU, + MemorySSAUpdater *MSSAU) { BasicBlock *BB = I->getParent(); std::vector <DominatorTree::UpdateType> Updates; + if (MSSAU) + MSSAU->changeToUnreachable(I); + // Loop over all of the successors, removing BB's entry from any PHI // nodes. if (DTU) @@ -1928,7 +1953,7 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap, ++NumInstrsRemoved; } if (DTU) - DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->applyUpdatesPermissive(Updates); return NumInstrsRemoved; } @@ -1937,8 +1962,8 @@ static void changeToCall(InvokeInst *II, DomTreeUpdater *DTU = nullptr) { SmallVector<Value*, 8> Args(II->arg_begin(), II->arg_end()); SmallVector<OperandBundleDef, 1> OpBundles; II->getOperandBundlesAsDefs(OpBundles); - CallInst *NewCall = CallInst::Create(II->getCalledValue(), Args, OpBundles, - "", II); + CallInst *NewCall = CallInst::Create( + II->getFunctionType(), II->getCalledValue(), Args, OpBundles, "", II); NewCall->takeName(II); NewCall->setCallingConv(II->getCallingConv()); NewCall->setAttributes(II->getAttributes()); @@ -1956,7 +1981,7 @@ static void changeToCall(InvokeInst *II, DomTreeUpdater *DTU = nullptr) { UnwindDestBB->removePredecessor(BB); II->eraseFromParent(); if (DTU) - DTU->deleteEdgeRelaxed(BB, UnwindDestBB); + DTU->applyUpdatesPermissive({{DominatorTree::Delete, BB, UnwindDestBB}}); } BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, @@ -1981,8 +2006,9 @@ BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, // can potentially be avoided with a cleverer API design that we do not have // as of this time. - InvokeInst *II = InvokeInst::Create(CI->getCalledValue(), Split, UnwindEdge, - InvokeArgs, OpBundles, CI->getName(), BB); + InvokeInst *II = + InvokeInst::Create(CI->getFunctionType(), CI->getCalledValue(), Split, + UnwindEdge, InvokeArgs, OpBundles, CI->getName(), BB); II->setDebugLoc(CI->getDebugLoc()); II->setCallingConv(CI->getCallingConv()); II->setAttributes(CI->getAttributes()); @@ -2052,7 +2078,7 @@ static bool markAliveBlocks(Function &F, Changed = true; break; } - if (CI->doesNotReturn()) { + if (CI->doesNotReturn() && !CI->isMustTailCall()) { // If we found a call to a no-return function, insert an unreachable // instruction after it. Make sure there isn't *already* one there // though. 
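// A sketch (not part of this patch) of the explicit-FunctionType creation
// form that the changeToCall()/changeToInvokeAndSplitBasicBlock() hunks above
// switch to; passing the callee's FunctionType stops these helpers from
// digging it out of the callee pointer's type. Names are illustrative.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

static CallInst *emitCallTo(Function *Callee, ArrayRef<Value *> Args,
                            Instruction *InsertBefore) {
  return CallInst::Create(Callee->getFunctionType(), Callee, Args,
                          /*NameStr=*/"", InsertBefore);
}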
@@ -2102,7 +2128,8 @@ static bool markAliveBlocks(Function &F, UnwindDestBB->removePredecessor(II->getParent()); II->eraseFromParent(); if (DTU) - DTU->deleteEdgeRelaxed(BB, UnwindDestBB); + DTU->applyUpdatesPermissive( + {{DominatorTree::Delete, BB, UnwindDestBB}}); } else changeToCall(II, DTU); Changed = true; @@ -2191,7 +2218,7 @@ void llvm::removeUnwindEdge(BasicBlock *BB, DomTreeUpdater *DTU) { TI->replaceAllUsesWith(NewTI); TI->eraseFromParent(); if (DTU) - DTU->deleteEdgeRelaxed(BB, UnwindDest); + DTU->applyUpdatesPermissive({{DominatorTree::Delete, BB, UnwindDest}}); } /// removeUnreachableBlocks - Remove blocks that are not reachable, even @@ -2211,7 +2238,7 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI, assert(Reachable.size() < F.size()); NumRemoved += F.size()-Reachable.size(); - SmallPtrSet<BasicBlock *, 16> DeadBlockSet; + SmallSetVector<BasicBlock *, 8> DeadBlockSet; for (Function::iterator I = ++F.begin(), E = F.end(); I != E; ++I) { auto *BB = &*I; if (Reachable.count(BB)) @@ -2256,7 +2283,7 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI, } if (DTU) { - DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->applyUpdatesPermissive(Updates); bool Deleted = false; for (auto *BB : DeadBlockSet) { if (DTU->isBBPendingDeletion(BB)) @@ -2450,12 +2477,12 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, return ::replaceDominatedUsesWith(From, To, BB, ProperlyDominates); } -bool llvm::callsGCLeafFunction(ImmutableCallSite CS, +bool llvm::callsGCLeafFunction(const CallBase *Call, const TargetLibraryInfo &TLI) { // Check if the function is specifically marked as a gc leaf function. - if (CS.hasFnAttr("gc-leaf-function")) + if (Call->hasFnAttr("gc-leaf-function")) return true; - if (const Function *F = CS.getCalledFunction()) { + if (const Function *F = Call->getCalledFunction()) { if (F->hasFnAttribute("gc-leaf-function")) return true; @@ -2469,7 +2496,7 @@ bool llvm::callsGCLeafFunction(ImmutableCallSite CS, // marked as 'gc-leaf-function.' All available Libcalls are // GC-leaf. LibFunc LF; - if (TLI.getLibFunc(CS, LF)) { + if (TLI.getLibFunc(ImmutableCallSite(Call), LF)) { return TLI.has(LF); } @@ -2530,13 +2557,13 @@ void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt, BasicBlock *BB) { // Since we are moving the instructions out of its basic block, we do not // retain their original debug locations (DILocations) and debug intrinsic - // instructions (dbg.values). + // instructions. // // Doing so would degrade the debugging experience and adversely affect the // accuracy of profiling information. // // Currently, when hoisting the instructions, we take the following actions: - // - Remove their dbg.values. + // - Remove their debug intrinsic instructions. // - Set their debug locations to the values from the insertion point. // // As per PR39141 (comment #8), the more fundamental reason why the dbg.values @@ -2554,7 +2581,7 @@ void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt, I->dropUnknownNonDebugMetadata(); if (I->isUsedByMetadata()) dropDebugUsers(*I); - if (isa<DbgVariableIntrinsic>(I)) { + if (isa<DbgInfoIntrinsic>(I)) { // Remove DbgInfo Intrinsics. II = I->eraseFromParent(); continue; @@ -2613,7 +2640,7 @@ struct BitPart { /// does not invalidate internal references (std::map instead of DenseMap). 
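// A sketch (not part of this patch) of the CallBase-based form of
// callsGCLeafFunction() shown above: callers now pass the call instruction
// directly instead of wrapping it in an ImmutableCallSite. The helper name is
// illustrative.
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

static bool isGCLeafCall(const Instruction &I, const TargetLibraryInfo &TLI) {
  if (const auto *Call = dyn_cast<CallBase>(&I))
    return callsGCLeafFunction(Call, TLI);
  return false;
}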
static const Optional<BitPart> & collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, - std::map<Value *, Optional<BitPart>> &BPS) { + std::map<Value *, Optional<BitPart>> &BPS, int Depth) { auto I = BPS.find(V); if (I != BPS.end()) return I->second; @@ -2621,13 +2648,19 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, auto &Result = BPS[V] = None; auto BitWidth = cast<IntegerType>(V->getType())->getBitWidth(); + // Prevent stack overflow by limiting the recursion depth + if (Depth == BitPartRecursionMaxDepth) { + LLVM_DEBUG(dbgs() << "collectBitParts max recursion depth reached.\n"); + return Result; + } + if (Instruction *I = dyn_cast<Instruction>(V)) { // If this is an or instruction, it may be an inner node of the bswap. if (I->getOpcode() == Instruction::Or) { auto &A = collectBitParts(I->getOperand(0), MatchBSwaps, - MatchBitReversals, BPS); + MatchBitReversals, BPS, Depth + 1); auto &B = collectBitParts(I->getOperand(1), MatchBSwaps, - MatchBitReversals, BPS); + MatchBitReversals, BPS, Depth + 1); if (!A || !B) return Result; @@ -2660,7 +2693,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, return Result; auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps, - MatchBitReversals, BPS); + MatchBitReversals, BPS, Depth + 1); if (!Res) return Result; Result = Res; @@ -2692,7 +2725,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, return Result; auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps, - MatchBitReversals, BPS); + MatchBitReversals, BPS, Depth + 1); if (!Res) return Result; Result = Res; @@ -2707,7 +2740,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, // If this is a zext instruction zero extend the result. if (I->getOpcode() == Instruction::ZExt) { auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps, - MatchBitReversals, BPS); + MatchBitReversals, BPS, Depth + 1); if (!Res) return Result; @@ -2769,7 +2802,7 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( // Try to find all the pieces corresponding to the bswap. std::map<Value *, Optional<BitPart>> BPS; - auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS); + auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0); if (!Res) return false; auto &BitProvenance = Res->Provenance; @@ -2883,3 +2916,41 @@ bool llvm::canReplaceOperandWithVariable(const Instruction *I, unsigned OpIdx) { return true; } } + +using AllocaForValueMapTy = DenseMap<Value *, AllocaInst *>; +AllocaInst *llvm::findAllocaForValue(Value *V, + AllocaForValueMapTy &AllocaForValue) { + if (AllocaInst *AI = dyn_cast<AllocaInst>(V)) + return AI; + // See if we've already calculated (or started to calculate) alloca for a + // given value. + AllocaForValueMapTy::iterator I = AllocaForValue.find(V); + if (I != AllocaForValue.end()) + return I->second; + // Store 0 while we're calculating alloca for value V to avoid + // infinite recursion if the value references itself. + AllocaForValue[V] = nullptr; + AllocaInst *Res = nullptr; + if (CastInst *CI = dyn_cast<CastInst>(V)) + Res = findAllocaForValue(CI->getOperand(0), AllocaForValue); + else if (PHINode *PN = dyn_cast<PHINode>(V)) { + for (Value *IncValue : PN->incoming_values()) { + // Allow self-referencing phi-nodes. + if (IncValue == PN) + continue; + AllocaInst *IncValueAI = findAllocaForValue(IncValue, AllocaForValue); + // AI for incoming values should exist and should all be equal. 
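// A usage sketch (not part of this patch) for the findAllocaForValue() helper
// being added here: the caller owns the memo map, so repeated queries that
// walk the same chains of casts, PHIs and GEPs are answered from the cache.
// The wrapper name is illustrative.
#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

static AllocaInst *allocaBehind(Value *V) {
  DenseMap<Value *, AllocaInst *> AllocaForValue;
  return findAllocaForValue(V, AllocaForValue);
}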
+ if (IncValueAI == nullptr || (Res != nullptr && IncValueAI != Res)) + return nullptr; + Res = IncValueAI; + } + } else if (GetElementPtrInst *EP = dyn_cast<GetElementPtrInst>(V)) { + Res = findAllocaForValue(EP->getPointerOperand(), AllocaForValue); + } else { + LLVM_DEBUG(dbgs() << "Alloca search cancelled on unknown instruction: " + << *V << "\n"); + } + if (Res) + AllocaForValue[V] = Res; + return Res; +} diff --git a/lib/Transforms/Utils/LoopRotationUtils.cpp b/lib/Transforms/Utils/LoopRotationUtils.cpp index 41f14a834617..37389a695b45 100644 --- a/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -1,9 +1,8 @@ //===----------------- LoopRotationUtils.cpp -----------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -17,6 +16,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopPass.h" @@ -28,7 +28,6 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IntrinsicInst.h" @@ -296,7 +295,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // Begin by walking OrigHeader and populating ValueMap with an entry for // each Instruction. BasicBlock::iterator I = OrigHeader->begin(), E = OrigHeader->end(); - ValueToValueMapTy ValueMap; + ValueToValueMapTy ValueMap, ValueMapMSSA; // For PHI nodes, the value available in OldPreHeader is just the // incoming value from OldPreHeader. @@ -375,6 +374,9 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { if (auto *II = dyn_cast<IntrinsicInst>(C)) if (II->getIntrinsicID() == Intrinsic::assume) AC->registerAssumption(II); + // MemorySSA cares whether the cloned instruction was inserted or not, and + // not whether it can be remapped to a simplified value. + ValueMapMSSA[Inst] = C; } } @@ -392,10 +394,11 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { LoopEntryBranch->eraseFromParent(); // Update MemorySSA before the rewrite call below changes the 1:1 - // instruction:cloned_instruction_or_value mapping in ValueMap. + // instruction:cloned_instruction_or_value mapping. if (MSSAU) { - ValueMap[OrigHeader] = OrigPreheader; - MSSAU->updateForClonedBlockIntoPred(OrigHeader, OrigPreheader, ValueMap); + ValueMapMSSA[OrigHeader] = OrigPreheader; + MSSAU->updateForClonedBlockIntoPred(OrigHeader, OrigPreheader, + ValueMapMSSA); } SmallVector<PHINode*, 2> InsertedPHIs; @@ -463,9 +466,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { for (BasicBlock *ExitPred : ExitPreds) { // We only need to split loop exit edges. 
Loop *PredLoop = LI->getLoopFor(ExitPred); - if (!PredLoop || PredLoop->contains(Exit)) - continue; - if (isa<IndirectBrInst>(ExitPred->getTerminator())) + if (!PredLoop || PredLoop->contains(Exit) || + ExitPred->getTerminator()->isIndirectTerminator()) continue; SplitLatchEdge |= L->getLoopLatch() == ExitPred; BasicBlock *ExitSplit = SplitCriticalEdge( diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp index 380f4fca54d9..7e6da02d5707 100644 --- a/lib/Transforms/Utils/LoopSimplify.cpp +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -1,9 +1,8 @@ //===- LoopSimplify.cpp - Loop Canonicalization Pass ----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -28,6 +27,9 @@ // to transform the loop and make these guarantees. Client code should check // that these conditions are true before relying on them. // +// Similar complications arise from callbr instructions, particularly in +// asm-goto where blockaddress expressions are used. +// // Note that the simplifycfg pass will clean up blocks which are split out but // end up being unnecessary, so usage of this pass should not pessimize // generated code. @@ -46,13 +48,15 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -67,6 +71,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -115,7 +120,8 @@ static void placeSplitBlockCarefully(BasicBlock *NewBB, /// preheader insertion and analysis updating. /// BasicBlock *llvm::InsertPreheaderForLoop(Loop *L, DominatorTree *DT, - LoopInfo *LI, bool PreserveLCSSA) { + LoopInfo *LI, MemorySSAUpdater *MSSAU, + bool PreserveLCSSA) { BasicBlock *Header = L->getHeader(); // Compute the set of predecessors of the loop that are not in the loop. @@ -124,10 +130,11 @@ BasicBlock *llvm::InsertPreheaderForLoop(Loop *L, DominatorTree *DT, PI != PE; ++PI) { BasicBlock *P = *PI; if (!L->contains(P)) { // Coming in from outside the loop? - // If the loop is branched to from an indirect branch, we won't + // If the loop is branched to from an indirect terminator, we won't // be able to fully transform the loop, because it prohibits // edge splitting. - if (isa<IndirectBrInst>(P->getTerminator())) return nullptr; + if (P->getTerminator()->isIndirectTerminator()) + return nullptr; // Keep track of it. 
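// A sketch (not part of this patch) of the predicate these hunks switch to:
// Instruction::isIndirectTerminator() is true for both indirectbr and callbr,
// i.e. exactly the terminators whose edges cannot be split, whereas the old
// isa<IndirectBrInst> checks missed asm-goto edges.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

static bool predHasUnsplittableEdges(const BasicBlock *Pred) {
  return Pred->getTerminator()->isIndirectTerminator();
}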
OutsideBlocks.push_back(P); @@ -137,7 +144,7 @@ BasicBlock *llvm::InsertPreheaderForLoop(Loop *L, DominatorTree *DT, // Split out the loop pre-header. BasicBlock *PreheaderBB; PreheaderBB = SplitBlockPredecessors(Header, OutsideBlocks, ".preheader", DT, - LI, nullptr, PreserveLCSSA); + LI, MSSAU, PreserveLCSSA); if (!PreheaderBB) return nullptr; @@ -217,7 +224,7 @@ static PHINode *findPHIToPartitionLoops(Loop *L, DominatorTree *DT, static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, bool PreserveLCSSA, - AssumptionCache *AC) { + AssumptionCache *AC, MemorySSAUpdater *MSSAU) { // Don't try to separate loops without a preheader. if (!Preheader) return nullptr; @@ -236,8 +243,8 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { if (PN->getIncomingValue(i) != PN || !L->contains(PN->getIncomingBlock(i))) { - // We can't split indirectbr edges. - if (isa<IndirectBrInst>(PN->getIncomingBlock(i)->getTerminator())) + // We can't split indirect control flow edges. + if (PN->getIncomingBlock(i)->getTerminator()->isIndirectTerminator()) return nullptr; OuterLoopPreds.push_back(PN->getIncomingBlock(i)); } @@ -251,7 +258,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, SE->forgetLoop(L); BasicBlock *NewBB = SplitBlockPredecessors(Header, OuterLoopPreds, ".outer", - DT, LI, nullptr, PreserveLCSSA); + DT, LI, MSSAU, PreserveLCSSA); // Make sure that NewBB is put someplace intelligent, which doesn't mess up // code layout too horribly. @@ -314,7 +321,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, // Split edges to exit blocks from the inner loop, if they emerged in the // process of separating the outer one. - formDedicatedExitBlocks(L, DT, LI, PreserveLCSSA); + formDedicatedExitBlocks(L, DT, LI, MSSAU, PreserveLCSSA); if (PreserveLCSSA) { // Fix LCSSA form for L. Some values, which previously were only used inside @@ -339,7 +346,8 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, /// and have that block branch to the loop header. This ensures that loops /// have exactly one backedge. static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, - DominatorTree *DT, LoopInfo *LI) { + DominatorTree *DT, LoopInfo *LI, + MemorySSAUpdater *MSSAU) { assert(L->getNumBackEdges() > 1 && "Must have > 1 backedge!"); // Get information about the loop @@ -358,8 +366,8 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, for (pred_iterator I = pred_begin(Header), E = pred_end(Header); I != E; ++I){ BasicBlock *P = *I; - // Indirectbr edges cannot be split, so we must fail if we find one. - if (isa<IndirectBrInst>(P->getTerminator())) + // Indirect edges cannot be split, so we must fail if we find one. 
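// A sketch (not part of this patch) of the optional MemorySSAUpdater being
// threaded through the loop-canonicalization utilities in the surrounding
// hunks: passing nullptr keeps the old behaviour, passing an updater keeps
// MemorySSA valid across the block splits. Include paths and the helper name
// are assumptions for illustration.
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
using namespace llvm;

static void canonicalizeLoop(Loop *L, DominatorTree *DT, LoopInfo *LI,
                             MemorySSAUpdater *MSSAU, bool PreserveLCSSA) {
  InsertPreheaderForLoop(L, DT, LI, MSSAU, PreserveLCSSA);
  formDedicatedExitBlocks(L, DT, LI, MSSAU, PreserveLCSSA);
}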
+ if (P->getTerminator()->isIndirectTerminator()) return nullptr; if (P != Preheader) BackedgeBlocks.push_back(P); @@ -439,9 +447,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, if (!LoopMD) LoopMD = TI->getMetadata(LoopMDKind); TI->setMetadata(LoopMDKind, nullptr); - for (unsigned Op = 0, e = TI->getNumSuccessors(); Op != e; ++Op) - if (TI->getSuccessor(Op) == Header) - TI->setSuccessor(Op, BEBlock); + TI->replaceSuccessorWith(Header, BEBlock); } BEBlock->getTerminator()->setMetadata(LoopMDKind, LoopMD); @@ -454,6 +460,10 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, // Update dominator information DT->splitBlock(BEBlock); + if (MSSAU) + MSSAU->updatePhisWhenInsertingUniqueBackedgeBlock(Header, Preheader, + BEBlock); + return BEBlock; } @@ -461,8 +471,11 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, static bool simplifyOneLoop(Loop *L, SmallVectorImpl<Loop *> &Worklist, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, - bool PreserveLCSSA) { + MemorySSAUpdater *MSSAU, bool PreserveLCSSA) { bool Changed = false; + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + ReprocessLoop: // Check to see that no blocks (other than the header) in this loop have @@ -489,11 +502,15 @@ ReprocessLoop: // Zap the dead pred's terminator and replace it with unreachable. Instruction *TI = P->getTerminator(); - changeToUnreachable(TI, /*UseLLVMTrap=*/false, PreserveLCSSA); + changeToUnreachable(TI, /*UseLLVMTrap=*/false, PreserveLCSSA, + /*DTU=*/nullptr, MSSAU); Changed = true; } } + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + // If there are exiting blocks with branches on undef, resolve the undef in // the direction which will exit the loop. This will help simplify loop // trip count computations. @@ -518,7 +535,7 @@ ReprocessLoop: // Does the loop already have a preheader? If so, don't insert one. BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { - Preheader = InsertPreheaderForLoop(L, DT, LI, PreserveLCSSA); + Preheader = InsertPreheaderForLoop(L, DT, LI, MSSAU, PreserveLCSSA); if (Preheader) Changed = true; } @@ -527,9 +544,12 @@ ReprocessLoop: // predecessors that are inside of the loop. This check guarantees that the // loop preheader/header will dominate the exit blocks. If the exit block has // predecessors from outside of the loop, split the edge now. - if (formDedicatedExitBlocks(L, DT, LI, PreserveLCSSA)) + if (formDedicatedExitBlocks(L, DT, LI, MSSAU, PreserveLCSSA)) Changed = true; + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + // If the header has more than two predecessors at this point (from the // preheader and from multiple backedges), we must adjust the loop. BasicBlock *LoopLatch = L->getLoopLatch(); @@ -538,8 +558,8 @@ ReprocessLoop: // this for loops with a giant number of backedges, just factor them into a // common backedge instead. if (L->getNumBackEdges() < 8) { - if (Loop *OuterL = - separateNestedLoop(L, Preheader, DT, LI, SE, PreserveLCSSA, AC)) { + if (Loop *OuterL = separateNestedLoop(L, Preheader, DT, LI, SE, + PreserveLCSSA, AC, MSSAU)) { ++NumNested; // Enqueue the outer loop as it should be processed next in our // depth-first nest walk. @@ -556,11 +576,14 @@ ReprocessLoop: // If we either couldn't, or didn't want to, identify nesting of the loops, // insert a new block that all backedges target, then make it jump to the // loop header. 
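// A sketch (not part of this patch) of Instruction::replaceSuccessorWith(),
// which replaces the manual successor-rewriting loop deleted above: every
// successor edge that pointed at Header is redirected to BEBlock in one call.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

static void retargetBackedge(Instruction *TI, BasicBlock *Header,
                             BasicBlock *BEBlock) {
  TI->replaceSuccessorWith(Header, BEBlock);
}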
- LoopLatch = insertUniqueBackedgeBlock(L, Preheader, DT, LI); + LoopLatch = insertUniqueBackedgeBlock(L, Preheader, DT, LI, MSSAU); if (LoopLatch) Changed = true; } + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); // Scan over the PHI nodes in the loop header. Since they now have only two @@ -618,9 +641,9 @@ ReprocessLoop: Instruction *Inst = &*I++; if (Inst == CI) continue; - if (!L->makeLoopInvariant(Inst, AnyInvariant, - Preheader ? Preheader->getTerminator() - : nullptr)) { + if (!L->makeLoopInvariant( + Inst, AnyInvariant, + Preheader ? Preheader->getTerminator() : nullptr, MSSAU)) { AllInvariant = false; break; } @@ -637,7 +660,7 @@ ReprocessLoop: // The block has now been cleared of all instructions except for // a comparison and a conditional branch. SimplifyCFG may be able // to fold it now. - if (!FoldBranchToCommonDest(BI)) + if (!FoldBranchToCommonDest(BI, MSSAU)) continue; // Success. The block is now dead, so remove it from the loop, @@ -657,11 +680,16 @@ ReprocessLoop: DT->changeImmediateDominator(Child, Node->getIDom()); } DT->eraseNode(ExitingBlock); + if (MSSAU) { + SmallSetVector<BasicBlock *, 8> ExitBlockSet; + ExitBlockSet.insert(ExitingBlock); + MSSAU->removeBlocks(ExitBlockSet); + } BI->getSuccessor(0)->removePredecessor( - ExitingBlock, /* DontDeleteUselessPHIs */ PreserveLCSSA); + ExitingBlock, /* KeepOneInputPHIs */ PreserveLCSSA); BI->getSuccessor(1)->removePredecessor( - ExitingBlock, /* DontDeleteUselessPHIs */ PreserveLCSSA); + ExitingBlock, /* KeepOneInputPHIs */ PreserveLCSSA); ExitingBlock->eraseFromParent(); } } @@ -672,12 +700,15 @@ ReprocessLoop: if (Changed && SE) SE->forgetTopmostLoop(L); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + return Changed; } bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, - bool PreserveLCSSA) { + MemorySSAUpdater *MSSAU, bool PreserveLCSSA) { bool Changed = false; #ifndef NDEBUG @@ -705,7 +736,7 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, while (!Worklist.empty()) Changed |= simplifyOneLoop(Worklist.pop_back_val(), Worklist, DT, LI, SE, - AC, PreserveLCSSA); + AC, MSSAU, PreserveLCSSA); return Changed; } @@ -737,6 +768,9 @@ namespace { AU.addPreservedID(LCSSAID); AU.addPreserved<DependenceAnalysisWrapperPass>(); AU.addPreservedID(BreakCriticalEdgesID); // No critical edges added. + AU.addPreserved<BranchProbabilityInfoWrapperPass>(); + if (EnableMSSALoopDependency) + AU.addPreserved<MemorySSAWrapperPass>(); } /// verifyAnalysis() - Verify LoopSimplifyForm's guarantees. @@ -768,12 +802,21 @@ bool LoopSimplify::runOnFunction(Function &F) { ScalarEvolution *SE = SEWP ? &SEWP->getSE() : nullptr; AssumptionCache *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + MemorySSA *MSSA = nullptr; + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (EnableMSSALoopDependency) { + auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + if (MSSAAnalysis) { + MSSA = &MSSAAnalysis->getMSSA(); + MSSAU = make_unique<MemorySSAUpdater>(MSSA); + } + } bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); // Simplify each loop nest in the function. 
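// A sketch (not part of this patch) of driving the updated simplifyLoop()
// entry point with MemorySSA, mirroring the runOnFunction() hunk above; a
// null updater leaves MemorySSA untouched (and therefore invalidated). The
// driver name is illustrative.
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include <memory>
using namespace llvm;

static bool simplifyAllLoops(LoopInfo *LI, DominatorTree *DT,
                             ScalarEvolution *SE, AssumptionCache *AC,
                             MemorySSA *MSSA) {
  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = llvm::make_unique<MemorySSAUpdater>(MSSA);
  bool Changed = false;
  for (Loop *L : *LI)
    Changed |= simplifyLoop(L, DT, LI, SE, AC, MSSAU.get(),
                            /*PreserveLCSSA=*/false);
  return Changed;
}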
for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) - Changed |= simplifyLoop(*I, DT, LI, SE, AC, PreserveLCSSA); + Changed |= simplifyLoop(*I, DT, LI, SE, AC, MSSAU.get(), PreserveLCSSA); #ifndef NDEBUG if (PreserveLCSSA) { @@ -794,9 +837,10 @@ PreservedAnalyses LoopSimplifyPass::run(Function &F, AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); // Note that we don't preserve LCSSA in the new PM, if you need it run LCSSA - // after simplifying the loops. + // after simplifying the loops. MemorySSA is not preserved either. for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) - Changed |= simplifyLoop(*I, DT, LI, SE, AC, /*PreserveLCSSA*/ false); + Changed |= + simplifyLoop(*I, DT, LI, SE, AC, nullptr, /*PreserveLCSSA*/ false); if (!Changed) return PreservedAnalyses::all(); @@ -809,6 +853,12 @@ PreservedAnalyses LoopSimplifyPass::run(Function &F, PA.preserve<SCEVAA>(); PA.preserve<ScalarEvolutionAnalysis>(); PA.preserve<DependenceAnalysis>(); + // BPI maps conditional terminators to probabilities, LoopSimplify can insert + // blocks, but it does so only by splitting existing blocks and edges. This + // results in the interesting property that all new terminators inserted are + // unconditional branches which do not appear in BPI. All deletions are + // handled via ValueHandle callbacks w/in BPI. + PA.preserve<BranchProbabilityAnalysis>(); return PA; } diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp index da7ed2bd1652..e39ade523714 100644 --- a/lib/Transforms/Utils/LoopUnroll.cpp +++ b/lib/Transforms/Utils/LoopUnroll.cpp @@ -1,9 +1,8 @@ //===-- UnrollLoop.cpp - Loop unrolling utilities -------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -45,6 +44,8 @@ using namespace llvm; // TODO: Should these be here or in LoopUnroll? STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled"); STATISTIC(NumUnrolled, "Number of loops unrolled (completely or otherwise)"); +STATISTIC(NumUnrolledWithHeader, "Number of loops unrolled without a " + "conditional latch (completely or otherwise)"); static cl::opt<bool> UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(false), cl::Hidden, @@ -94,66 +95,6 @@ void llvm::remapInstruction(Instruction *I, ValueToValueMapTy &VMap) { } } -/// Folds a basic block into its predecessor if it only has one predecessor, and -/// that predecessor only has one successor. -/// The LoopInfo Analysis that is passed will be kept consistent. -BasicBlock *llvm::foldBlockIntoPredecessor(BasicBlock *BB, LoopInfo *LI, - ScalarEvolution *SE, - DominatorTree *DT) { - // Merge basic blocks into their predecessor if there is only one distinct - // pred, and if there is only one distinct successor of the predecessor, and - // if there are no PHI nodes. - BasicBlock *OnlyPred = BB->getSinglePredecessor(); - if (!OnlyPred) return nullptr; - - if (OnlyPred->getTerminator()->getNumSuccessors() != 1) - return nullptr; - - LLVM_DEBUG(dbgs() << "Merging: " << BB->getName() << " into " - << OnlyPred->getName() << "\n"); - - // Resolve any PHI nodes at the start of the block. 
They are all - // guaranteed to have exactly one entry if they exist, unless there are - // multiple duplicate (but guaranteed to be equal) entries for the - // incoming edges. This occurs when there are multiple edges from - // OnlyPred to OnlySucc. - FoldSingleEntryPHINodes(BB); - - // Delete the unconditional branch from the predecessor... - OnlyPred->getInstList().pop_back(); - - // Make all PHI nodes that referred to BB now refer to Pred as their - // source... - BB->replaceAllUsesWith(OnlyPred); - - // Move all definitions in the successor to the predecessor... - OnlyPred->getInstList().splice(OnlyPred->end(), BB->getInstList()); - - // OldName will be valid until erased. - StringRef OldName = BB->getName(); - - // Erase the old block and update dominator info. - if (DT) - if (DomTreeNode *DTN = DT->getNode(BB)) { - DomTreeNode *PredDTN = DT->getNode(OnlyPred); - SmallVector<DomTreeNode *, 8> Children(DTN->begin(), DTN->end()); - for (auto *DI : Children) - DT->changeImmediateDominator(DI, PredDTN); - - DT->eraseNode(BB); - } - - LI->removeBlock(BB); - - // Inherit predecessor's name if it exists... - if (!OldName.empty() && !OnlyPred->hasName()) - OnlyPred->setName(OldName); - - BB->eraseFromParent(); - - return OnlyPred; -} - /// Check if unrolling created a situation where we need to insert phi nodes to /// preserve LCSSA form. /// \param Blocks is a vector of basic blocks representing unrolled loop. @@ -332,12 +273,11 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, /// /// If RemainderLoop is non-null, it will receive the remainder loop (if /// required and not fully unrolled). -LoopUnrollResult llvm::UnrollLoop( - Loop *L, unsigned Count, unsigned TripCount, bool Force, bool AllowRuntime, - bool AllowExpensiveTripCount, bool PreserveCondBr, bool PreserveOnlyFirst, - unsigned TripMultiple, unsigned PeelCount, bool UnrollRemainder, - LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC, - OptimizationRemarkEmitter *ORE, bool PreserveLCSSA, Loop **RemainderLoop) { +LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, + ScalarEvolution *SE, DominatorTree *DT, + AssumptionCache *AC, + OptimizationRemarkEmitter *ORE, + bool PreserveLCSSA, Loop **RemainderLoop) { BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { @@ -357,28 +297,46 @@ LoopUnrollResult llvm::UnrollLoop( return LoopUnrollResult::Unmodified; } - // The current loop unroll pass can only unroll loops with a single latch + // The current loop unroll pass can unroll loops with a single latch or header // that's a conditional branch exiting the loop. // FIXME: The implementation can be extended to work with more complicated // cases, e.g. loops with multiple latches. BasicBlock *Header = L->getHeader(); + BranchInst *HeaderBI = dyn_cast<BranchInst>(Header->getTerminator()); BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator()); - if (!BI || BI->isUnconditional()) { - // The loop-rotate pass can be helpful to avoid this in many cases. + // FIXME: Support loops without conditional latch and multiple exiting blocks. 
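// A sketch (not part of this patch) of calling the new option-struct overload
// of UnrollLoop() introduced in this hunk. The field names follow the ULO.*
// accesses visible in the diff; treat the exact struct layout as an
// assumption, and the wrapper name as illustrative.
#include "llvm/Transforms/Utils/UnrollLoop.h"
using namespace llvm;

static LoopUnrollResult unrollByFour(Loop *L, LoopInfo *LI,
                                     ScalarEvolution *SE, DominatorTree *DT,
                                     AssumptionCache *AC,
                                     OptimizationRemarkEmitter *ORE,
                                     bool PreserveLCSSA) {
  UnrollLoopOptions ULO;
  ULO.Count = 4;
  ULO.TripCount = 0;                 // unknown at compile time
  ULO.TripMultiple = 1;
  ULO.PeelCount = 0;
  ULO.Force = false;
  ULO.AllowRuntime = true;           // allow a runtime remainder loop
  ULO.AllowExpensiveTripCount = false;
  ULO.PreserveCondBr = false;
  ULO.PreserveOnlyFirst = false;
  ULO.UnrollRemainder = false;
  ULO.ForgetAllSCEV = false;
  Loop *RemainderLoop = nullptr;
  return UnrollLoop(L, ULO, LI, SE, DT, AC, ORE, PreserveLCSSA,
                    &RemainderLoop);
}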
+ if (!BI || + (BI->isUnconditional() && (!HeaderBI || HeaderBI->isUnconditional() || + L->getExitingBlock() != Header))) { + LLVM_DEBUG(dbgs() << " Can't unroll; loop not terminated by a conditional " + "branch in the latch or header.\n"); + return LoopUnrollResult::Unmodified; + } + + auto CheckLatchSuccessors = [&](unsigned S1, unsigned S2) { + return BI->isConditional() && BI->getSuccessor(S1) == Header && + !L->contains(BI->getSuccessor(S2)); + }; + + // If we have a conditional latch, it must exit the loop. + if (BI && BI->isConditional() && !CheckLatchSuccessors(0, 1) && + !CheckLatchSuccessors(1, 0)) { LLVM_DEBUG( - dbgs() - << " Can't unroll; loop not terminated by a conditional branch.\n"); + dbgs() << "Can't unroll; a conditional latch must exit the loop"); return LoopUnrollResult::Unmodified; } - auto CheckSuccessors = [&](unsigned S1, unsigned S2) { - return BI->getSuccessor(S1) == Header && !L->contains(BI->getSuccessor(S2)); + auto CheckHeaderSuccessors = [&](unsigned S1, unsigned S2) { + return HeaderBI && HeaderBI->isConditional() && + L->contains(HeaderBI->getSuccessor(S1)) && + !L->contains(HeaderBI->getSuccessor(S2)); }; - if (!CheckSuccessors(0, 1) && !CheckSuccessors(1, 0)) { - LLVM_DEBUG(dbgs() << "Can't unroll; only loops with one conditional latch" - " exiting the loop can be unrolled\n"); + // If we do not have a conditional latch, the header must exit the loop. + if (BI && !BI->isConditional() && HeaderBI && HeaderBI->isConditional() && + !CheckHeaderSuccessors(0, 1) && !CheckHeaderSuccessors(1, 0)) { + LLVM_DEBUG(dbgs() << "Can't unroll; conditional header must exit the loop"); return LoopUnrollResult::Unmodified; } @@ -389,28 +347,28 @@ LoopUnrollResult llvm::UnrollLoop( return LoopUnrollResult::Unmodified; } - if (TripCount != 0) - LLVM_DEBUG(dbgs() << " Trip Count = " << TripCount << "\n"); - if (TripMultiple != 1) - LLVM_DEBUG(dbgs() << " Trip Multiple = " << TripMultiple << "\n"); + if (ULO.TripCount != 0) + LLVM_DEBUG(dbgs() << " Trip Count = " << ULO.TripCount << "\n"); + if (ULO.TripMultiple != 1) + LLVM_DEBUG(dbgs() << " Trip Multiple = " << ULO.TripMultiple << "\n"); // Effectively "DCE" unrolled iterations that are beyond the tripcount // and will never be executed. - if (TripCount != 0 && Count > TripCount) - Count = TripCount; + if (ULO.TripCount != 0 && ULO.Count > ULO.TripCount) + ULO.Count = ULO.TripCount; // Don't enter the unroll code if there is nothing to do. - if (TripCount == 0 && Count < 2 && PeelCount == 0) { + if (ULO.TripCount == 0 && ULO.Count < 2 && ULO.PeelCount == 0) { LLVM_DEBUG(dbgs() << "Won't unroll; almost nothing to do\n"); return LoopUnrollResult::Unmodified; } - assert(Count > 0); - assert(TripMultiple > 0); - assert(TripCount == 0 || TripCount % TripMultiple == 0); + assert(ULO.Count > 0); + assert(ULO.TripMultiple > 0); + assert(ULO.TripCount == 0 || ULO.TripCount % ULO.TripMultiple == 0); // Are we eliminating the loop control altogether? - bool CompletelyUnroll = Count == TripCount; + bool CompletelyUnroll = ULO.Count == ULO.TripCount; SmallVector<BasicBlock *, 4> ExitBlocks; L->getExitBlocks(ExitBlocks); std::vector<BasicBlock*> OriginalLoopBlocks = L->getBlocks(); @@ -429,24 +387,29 @@ LoopUnrollResult llvm::UnrollLoop( // We assume a run-time trip count if the compiler cannot // figure out the loop trip count and the unroll-runtime // flag is specified. 
- bool RuntimeTripCount = (TripCount == 0 && Count > 0 && AllowRuntime); + bool RuntimeTripCount = + (ULO.TripCount == 0 && ULO.Count > 0 && ULO.AllowRuntime); - assert((!RuntimeTripCount || !PeelCount) && + assert((!RuntimeTripCount || !ULO.PeelCount) && "Did not expect runtime trip-count unrolling " "and peeling for the same loop"); bool Peeled = false; - if (PeelCount) { - Peeled = peelLoop(L, PeelCount, LI, SE, DT, AC, PreserveLCSSA); + if (ULO.PeelCount) { + Peeled = peelLoop(L, ULO.PeelCount, LI, SE, DT, AC, PreserveLCSSA); // Successful peeling may result in a change in the loop preheader/trip // counts. If we later unroll the loop, we want these to be updated. if (Peeled) { - BasicBlock *ExitingBlock = L->getExitingBlock(); + // According to our guards and profitability checks the only + // meaningful exit should be latch block. Other exits go to deopt, + // so we do not worry about them. + BasicBlock *ExitingBlock = L->getLoopLatch(); assert(ExitingBlock && "Loop without exiting block?"); + assert(L->isLoopExiting(ExitingBlock) && "Latch is not exiting?"); Preheader = L->getLoopPreheader(); - TripCount = SE->getSmallConstantTripCount(L, ExitingBlock); - TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock); + ULO.TripCount = SE->getSmallConstantTripCount(L, ExitingBlock); + ULO.TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock); } } @@ -459,7 +422,7 @@ LoopUnrollResult llvm::UnrollLoop( for (auto &I : *BB) if (auto CS = CallSite(&I)) HasConvergent |= CS.isConvergent(); - assert((!HasConvergent || TripMultiple % Count == 0) && + assert((!HasConvergent || ULO.TripMultiple % ULO.Count == 0) && "Unroll count must divide trip multiple if loop contains a " "convergent operation."); }); @@ -468,11 +431,12 @@ LoopUnrollResult llvm::UnrollLoop( UnrollRuntimeEpilog.getNumOccurrences() ? UnrollRuntimeEpilog : isEpilogProfitable(L); - if (RuntimeTripCount && TripMultiple % Count != 0 && - !UnrollRuntimeLoopRemainder(L, Count, AllowExpensiveTripCount, - EpilogProfitability, UnrollRemainder, LI, SE, - DT, AC, PreserveLCSSA, RemainderLoop)) { - if (Force) + if (RuntimeTripCount && ULO.TripMultiple % ULO.Count != 0 && + !UnrollRuntimeLoopRemainder(L, ULO.Count, ULO.AllowExpensiveTripCount, + EpilogProfitability, ULO.UnrollRemainder, + ULO.ForgetAllSCEV, LI, SE, DT, AC, + PreserveLCSSA, RemainderLoop)) { + if (ULO.Force) RuntimeTripCount = false; else { LLVM_DEBUG(dbgs() << "Won't unroll; remainder loop could not be " @@ -483,35 +447,35 @@ LoopUnrollResult llvm::UnrollLoop( // If we know the trip count, we know the multiple... unsigned BreakoutTrip = 0; - if (TripCount != 0) { - BreakoutTrip = TripCount % Count; - TripMultiple = 0; + if (ULO.TripCount != 0) { + BreakoutTrip = ULO.TripCount % ULO.Count; + ULO.TripMultiple = 0; } else { // Figure out what multiple to use. - BreakoutTrip = TripMultiple = - (unsigned)GreatestCommonDivisor64(Count, TripMultiple); + BreakoutTrip = ULO.TripMultiple = + (unsigned)GreatestCommonDivisor64(ULO.Count, ULO.TripMultiple); } using namespace ore; // Report the unrolling decision. 
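// A worked example (not part of this patch) of the breakout computation
// above: with ULO.Count = 4 and an unknown trip count whose known multiple is
// 6, BreakoutTrip = TripMultiple = gcd(4, 6) = 2, so every second unrolled
// copy keeps its conditional exit branch.
#include "llvm/Support/MathExtras.h"

static unsigned breakoutTrip(unsigned Count, unsigned TripMultiple) {
  return (unsigned)llvm::GreatestCommonDivisor64(Count, TripMultiple);
}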
if (CompletelyUnroll) { LLVM_DEBUG(dbgs() << "COMPLETELY UNROLLING loop %" << Header->getName() - << " with trip count " << TripCount << "!\n"); + << " with trip count " << ULO.TripCount << "!\n"); if (ORE) ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(), L->getHeader()) << "completely unrolled loop with " - << NV("UnrollCount", TripCount) << " iterations"; + << NV("UnrollCount", ULO.TripCount) << " iterations"; }); - } else if (PeelCount) { + } else if (ULO.PeelCount) { LLVM_DEBUG(dbgs() << "PEELING loop %" << Header->getName() - << " with iteration count " << PeelCount << "!\n"); + << " with iteration count " << ULO.PeelCount << "!\n"); if (ORE) ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "Peeled", L->getStartLoc(), L->getHeader()) - << " peeled loop by " << NV("PeelCount", PeelCount) + << " peeled loop by " << NV("PeelCount", ULO.PeelCount) << " iterations"; }); } else { @@ -519,24 +483,25 @@ LoopUnrollResult llvm::UnrollLoop( OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(), L->getHeader()); return Diag << "unrolled loop by a factor of " - << NV("UnrollCount", Count); + << NV("UnrollCount", ULO.Count); }; LLVM_DEBUG(dbgs() << "UNROLLING loop %" << Header->getName() << " by " - << Count); - if (TripMultiple == 0 || BreakoutTrip != TripMultiple) { + << ULO.Count); + if (ULO.TripMultiple == 0 || BreakoutTrip != ULO.TripMultiple) { LLVM_DEBUG(dbgs() << " with a breakout at trip " << BreakoutTrip); if (ORE) ORE->emit([&]() { return DiagBuilder() << " with a breakout at trip " << NV("BreakoutTrip", BreakoutTrip); }); - } else if (TripMultiple != 1) { - LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch"); + } else if (ULO.TripMultiple != 1) { + LLVM_DEBUG(dbgs() << " with " << ULO.TripMultiple << " trips per branch"); if (ORE) ORE->emit([&]() { - return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple) - << " trips per branch"; + return DiagBuilder() + << " with " << NV("TripMultiple", ULO.TripMultiple) + << " trips per branch"; }); } else if (RuntimeTripCount) { LLVM_DEBUG(dbgs() << " with run-time trip count"); @@ -555,11 +520,24 @@ LoopUnrollResult llvm::UnrollLoop( // and if something changes inside them then any of outer loops may also // change. When we forget outermost loop, we also forget all contained loops // and this is what we need here. - if (SE) - SE->forgetTopmostLoop(L); + if (SE) { + if (ULO.ForgetAllSCEV) + SE->forgetAllLoops(); + else + SE->forgetTopmostLoop(L); + } - bool ContinueOnTrue = L->contains(BI->getSuccessor(0)); - BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue); + bool ContinueOnTrue; + bool LatchIsExiting = BI->isConditional(); + BasicBlock *LoopExit = nullptr; + if (LatchIsExiting) { + ContinueOnTrue = L->contains(BI->getSuccessor(0)); + LoopExit = BI->getSuccessor(ContinueOnTrue); + } else { + NumUnrolledWithHeader++; + ContinueOnTrue = L->contains(HeaderBI->getSuccessor(0)); + LoopExit = HeaderBI->getSuccessor(ContinueOnTrue); + } // For the first iteration of the loop, we should use the precloned values for // PHI nodes. Insert associations now. 
@@ -569,11 +547,23 @@ LoopUnrollResult llvm::UnrollLoop( OrigPHINode.push_back(cast<PHINode>(I)); } - std::vector<BasicBlock*> Headers; - std::vector<BasicBlock*> Latches; + std::vector<BasicBlock *> Headers; + std::vector<BasicBlock *> HeaderSucc; + std::vector<BasicBlock *> Latches; Headers.push_back(Header); Latches.push_back(LatchBlock); + if (!LatchIsExiting) { + auto *Term = cast<BranchInst>(Header->getTerminator()); + if (Term->isUnconditional() || L->contains(Term->getSuccessor(0))) { + assert(L->contains(Term->getSuccessor(0))); + HeaderSucc.push_back(Term->getSuccessor(0)); + } else { + assert(L->contains(Term->getSuccessor(1))); + HeaderSucc.push_back(Term->getSuccessor(1)); + } + } + // The current on-the-fly SSA update requires blocks to be processed in // reverse postorder so that LastValueMap contains the correct value at each // exit. @@ -599,7 +589,7 @@ LoopUnrollResult llvm::UnrollLoop( for (Instruction &I : *BB) if (!isa<DbgInfoIntrinsic>(&I)) if (const DILocation *DIL = I.getDebugLoc()) { - auto NewDIL = DIL->cloneWithDuplicationFactor(Count); + auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(ULO.Count); if (NewDIL) I.setDebugLoc(NewDIL.getValue()); else @@ -608,7 +598,7 @@ LoopUnrollResult llvm::UnrollLoop( << DIL->getFilename() << " Line: " << DIL->getLine()); } - for (unsigned It = 1; It != Count; ++It) { + for (unsigned It = 1; It != ULO.Count; ++It) { std::vector<BasicBlock*> NewBlocks; SmallDenseMap<const Loop *, Loop *, 4> NewLoops; NewLoops[L] = L; @@ -663,6 +653,13 @@ LoopUnrollResult llvm::UnrollLoop( if (*BB == LatchBlock) Latches.push_back(New); + // Keep track of the successor of the new header in the current iteration. + for (auto *Pred : predecessors(*BB)) + if (Pred == Header) { + HeaderSucc.push_back(New); + break; + } + NewBlocks.push_back(New); UnrolledLoopBlocks.push_back(New); @@ -699,8 +696,7 @@ LoopUnrollResult llvm::UnrollLoop( if (CompletelyUnroll) { PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader)); Header->getInstList().erase(PN); - } - else if (Count > 1) { + } else if (ULO.Count > 1) { Value *InVal = PN->removeIncomingValue(LatchBlock, false); // If this value was defined in the loop, take the value defined by the // last iteration of the loop. @@ -713,39 +709,11 @@ LoopUnrollResult llvm::UnrollLoop( } } - // Now that all the basic blocks for the unrolled iterations are in place, - // set up the branches to connect them. - for (unsigned i = 0, e = Latches.size(); i != e; ++i) { - // The original branch was replicated in each unrolled iteration. - BranchInst *Term = cast<BranchInst>(Latches[i]->getTerminator()); - - // The branch destination. - unsigned j = (i + 1) % e; - BasicBlock *Dest = Headers[j]; - bool NeedConditional = true; - - if (RuntimeTripCount && j != 0) { - NeedConditional = false; - } - - // For a complete unroll, make the last iteration end with a branch - // to the exit block. - if (CompletelyUnroll) { - if (j == 0) - Dest = LoopExit; - // If using trip count upper bound to completely unroll, we need to keep - // the conditional branch except the last one because the loop may exit - // after any iteration. - assert(NeedConditional && - "NeedCondition cannot be modified by both complete " - "unrolling and runtime unrolling"); - NeedConditional = (PreserveCondBr && j && !(PreserveOnlyFirst && i != 0)); - } else if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) { - // If we know the trip count or a multiple of it, we can safely use an - // unconditional branch for some iterations. 
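// A sketch (not part of this patch) of the renamed DILocation hook used in
// the cloning loop above when stamping each unrolled copy: it returns None
// when the duplication factor cannot be encoded, in which case the original
// location is kept (the pass only prints a debug message). The helper name is
// illustrative.
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

static void scaleDebugLoc(Instruction &I, unsigned Count) {
  if (const DILocation *DIL = I.getDebugLoc())
    if (auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count))
      I.setDebugLoc(NewDIL.getValue());
}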
- NeedConditional = false; - } - + auto setDest = [LoopExit, ContinueOnTrue](BasicBlock *Src, BasicBlock *Dest, + ArrayRef<BasicBlock *> NextBlocks, + BasicBlock *CurrentHeader, + bool NeedConditional) { + auto *Term = cast<BranchInst>(Src->getTerminator()); if (NeedConditional) { // Update the conditional branch's successor for the following // iteration. @@ -753,9 +721,9 @@ LoopUnrollResult llvm::UnrollLoop( } else { // Remove phi operands at this loop exit if (Dest != LoopExit) { - BasicBlock *BB = Latches[i]; - for (BasicBlock *Succ: successors(BB)) { - if (Succ == Headers[i]) + BasicBlock *BB = Src; + for (BasicBlock *Succ : successors(BB)) { + if (Succ == CurrentHeader) continue; for (PHINode &Phi : Succ->phis()) Phi.removeIncomingValue(BB, false); @@ -765,13 +733,97 @@ LoopUnrollResult llvm::UnrollLoop( BranchInst::Create(Dest, Term); Term->eraseFromParent(); } + }; + + // Now that all the basic blocks for the unrolled iterations are in place, + // set up the branches to connect them. + if (LatchIsExiting) { + // Set up latches to branch to the new header in the unrolled iterations or + // the loop exit for the last latch in a fully unrolled loop. + for (unsigned i = 0, e = Latches.size(); i != e; ++i) { + // The branch destination. + unsigned j = (i + 1) % e; + BasicBlock *Dest = Headers[j]; + bool NeedConditional = true; + + if (RuntimeTripCount && j != 0) { + NeedConditional = false; + } + + // For a complete unroll, make the last iteration end with a branch + // to the exit block. + if (CompletelyUnroll) { + if (j == 0) + Dest = LoopExit; + // If using trip count upper bound to completely unroll, we need to keep + // the conditional branch except the last one because the loop may exit + // after any iteration. + assert(NeedConditional && + "NeedCondition cannot be modified by both complete " + "unrolling and runtime unrolling"); + NeedConditional = + (ULO.PreserveCondBr && j && !(ULO.PreserveOnlyFirst && i != 0)); + } else if (j != BreakoutTrip && + (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) { + // If we know the trip count or a multiple of it, we can safely use an + // unconditional branch for some iterations. + NeedConditional = false; + } + + setDest(Latches[i], Dest, Headers, Headers[i], NeedConditional); + } + } else { + // Setup headers to branch to their new successors in the unrolled + // iterations. + for (unsigned i = 0, e = Headers.size(); i != e; ++i) { + // The branch destination. + unsigned j = (i + 1) % e; + BasicBlock *Dest = HeaderSucc[i]; + bool NeedConditional = true; + + if (RuntimeTripCount && j != 0) + NeedConditional = false; + + if (CompletelyUnroll) + // We cannot drop the conditional branch for the last condition, as we + // may have to execute the loop body depending on the condition. + NeedConditional = j == 0 || ULO.PreserveCondBr; + else if (j != BreakoutTrip && + (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) + // If we know the trip count or a multiple of it, we can safely use an + // unconditional branch for some iterations. + NeedConditional = false; + + setDest(Headers[i], Dest, Headers, Headers[i], NeedConditional); + } + + // Set up latches to branch to the new header in the unrolled iterations or + // the loop exit for the last latch in a fully unrolled loop. + + for (unsigned i = 0, e = Latches.size(); i != e; ++i) { + // The original branch was replicated in each unrolled iteration. + BranchInst *Term = cast<BranchInst>(Latches[i]->getTerminator()); + + // The branch destination. 
+ unsigned j = (i + 1) % e; + BasicBlock *Dest = Headers[j]; + + // When completely unrolling, the last latch becomes unreachable. + if (CompletelyUnroll && j == 0) + new UnreachableInst(Term->getContext(), Term); + else + // Replace the conditional branch with an unconditional one. + BranchInst::Create(Dest, Term); + + Term->eraseFromParent(); + } } // Update dominators of blocks we might reach through exits. // Immediate dominator of such block might change, because we add more // routes which can lead to the exit: we can now reach it from the copied // iterations too. - if (DT && Count > 1) { + if (DT && ULO.Count > 1) { for (auto *BB : OriginalLoopBlocks) { auto *BBDomNode = DT->getNode(BB); SmallVector<BasicBlock *, 16> ChildrenToUpdate; @@ -781,7 +833,9 @@ LoopUnrollResult llvm::UnrollLoop( ChildrenToUpdate.push_back(ChildBB); } BasicBlock *NewIDom; - if (BB == LatchBlock) { + BasicBlock *&TermBlock = LatchIsExiting ? LatchBlock : Header; + auto &TermBlocks = LatchIsExiting ? Latches : Headers; + if (BB == TermBlock) { // The latch is special because we emit unconditional branches in // some cases where the original loop contained a conditional branch. // Since the latch is always at the bottom of the loop, if the latch @@ -789,11 +843,13 @@ LoopUnrollResult llvm::UnrollLoop( // must also be a latch. Specifically, the dominator is the first // latch which ends in a conditional branch, or the last latch if // there is no such latch. - NewIDom = Latches.back(); - for (BasicBlock *IterLatch : Latches) { - Instruction *Term = IterLatch->getTerminator(); + // For loops exiting from the header, we limit the supported loops + // to have a single exiting block. + NewIDom = TermBlocks.back(); + for (BasicBlock *Iter : TermBlocks) { + Instruction *Term = Iter->getTerminator(); if (isa<BranchInst>(Term) && cast<BranchInst>(Term)->isConditional()) { - NewIDom = IterLatch; + NewIDom = Iter; break; } } @@ -810,14 +866,20 @@ LoopUnrollResult llvm::UnrollLoop( } assert(!DT || !UnrollVerifyDomtree || - DT->verify(DominatorTree::VerificationLevel::Fast)); + DT->verify(DominatorTree::VerificationLevel::Fast)); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); // Merge adjacent basic blocks, if possible. for (BasicBlock *Latch : Latches) { - BranchInst *Term = cast<BranchInst>(Latch->getTerminator()); - if (Term->isUnconditional()) { + BranchInst *Term = dyn_cast<BranchInst>(Latch->getTerminator()); + assert((Term || + (CompletelyUnroll && !LatchIsExiting && Latch == Latches.back())) && + "Need a branch as terminator, except when fully unrolling with " + "unconditional latch"); + if (Term && Term->isUnconditional()) { BasicBlock *Dest = Term->getSuccessor(0); - if (BasicBlock *Fold = foldBlockIntoPredecessor(Dest, LI, SE, DT)) { + BasicBlock *Fold = Dest->getUniquePredecessor(); + if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) { // Dest has been folded into Fold. Update our worklists accordingly. std::replace(Latches.begin(), Latches.end(), Dest, Fold); UnrolledLoopBlocks.erase(std::remove(UnrolledLoopBlocks.begin(), @@ -829,8 +891,8 @@ LoopUnrollResult llvm::UnrollLoop( // At this point, the code is well formed. We now simplify the unrolled loop, // doing constant propagation and dead code elimination as we go. 
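The NeedConditional logic replicated in the two loops above decides, for each unrolled copy i of e, whether its replicated exit branch can be folded into an unconditional branch into copy (i + 1) % e. A minimal stand-alone sketch of that decision with plain types and invented parameter names (not the in-tree API):

#include <cstdio>

static bool needsConditionalBranch(unsigned i, unsigned e, unsigned BreakoutTrip,
                                   unsigned TripMultiple, bool RuntimeTripCount,
                                   bool CompletelyUnroll, bool PreserveCondBr,
                                   bool PreserveOnlyFirst) {
  unsigned j = (i + 1) % e;        // copy i falls through into copy j
  bool NeedConditional = true;
  if (RuntimeTripCount && j != 0)
    NeedConditional = false;       // the runtime remainder handles early exits
  if (CompletelyUnroll)
    NeedConditional = PreserveCondBr && j && !(PreserveOnlyFirst && i != 0);
  else if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0))
    NeedConditional = false;       // a known trip multiple proves this exit is dead
  return NeedConditional;
}

int main() {
  // Unrolling by 4 with a known trip multiple of 4: only the last copy keeps
  // its conditional exit branch.
  for (unsigned i = 0; i != 4; ++i)
    std::printf("copy %u keeps conditional exit: %d\n", i,
                needsConditionalBranch(i, 4, 0, 4, false, false, false, false));
  return 0;
}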
- simplifyLoopAfterUnroll(L, !CompletelyUnroll && (Count > 1 || Peeled), LI, SE, - DT, AC); + simplifyLoopAfterUnroll(L, !CompletelyUnroll && (ULO.Count > 1 || Peeled), LI, + SE, DT, AC); NumCompletelyUnrolled += CompletelyUnroll; ++NumUnrolled; @@ -878,11 +940,11 @@ LoopUnrollResult llvm::UnrollLoop( // TODO: That potentially might be compile-time expensive. We should try // to fix the loop-simplified form incrementally. - simplifyLoop(OuterL, DT, LI, SE, AC, PreserveLCSSA); + simplifyLoop(OuterL, DT, LI, SE, AC, nullptr, PreserveLCSSA); } else { // Simplify loops for which we might've broken loop-simplify form. for (Loop *SubLoop : LoopsToSimplify) - simplifyLoop(SubLoop, DT, LI, SE, AC, PreserveLCSSA); + simplifyLoop(SubLoop, DT, LI, SE, AC, nullptr, PreserveLCSSA); } } diff --git a/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/lib/Transforms/Utils/LoopUnrollAndJam.cpp index e26762639c13..ff49d83f25c5 100644 --- a/lib/Transforms/Utils/LoopUnrollAndJam.cpp +++ b/lib/Transforms/Utils/LoopUnrollAndJam.cpp @@ -1,9 +1,8 @@ //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -198,8 +197,8 @@ LoopUnrollResult llvm::UnrollAndJamLoop( if (TripMultiple == 1 || TripMultiple % Count != 0) { if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false, /*UseEpilogRemainder*/ true, - UnrollRemainder, LI, SE, DT, AC, true, - EpilogueLoop)) { + UnrollRemainder, /*ForgetAllSCEV*/ false, + LI, SE, DT, AC, true, EpilogueLoop)) { LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be " "generated when assuming runtime trip count\n"); return LoopUnrollResult::Unmodified; @@ -301,7 +300,7 @@ LoopUnrollResult llvm::UnrollAndJamLoop( for (Instruction &I : *BB) if (!isa<DbgInfoIntrinsic>(&I)) if (const DILocation *DIL = I.getDebugLoc()) { - auto NewDIL = DIL->cloneWithDuplicationFactor(Count); + auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count); if (NewDIL) I.setDebugLoc(NewDIL.getValue()); else @@ -539,12 +538,14 @@ LoopUnrollResult llvm::UnrollAndJamLoop( MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end()); MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end()); MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end()); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); while (!MergeBlocks.empty()) { BasicBlock *BB = *MergeBlocks.begin(); BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator()); if (Term && Term->isUnconditional() && L->contains(Term->getSuccessor(0))) { BasicBlock *Dest = Term->getSuccessor(0); - if (BasicBlock *Fold = foldBlockIntoPredecessor(Dest, LI, SE, DT)) { + BasicBlock *Fold = Dest->getUniquePredecessor(); + if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) { // Don't remove BB and add Fold as they are the same BB assert(Fold == BB); (void)Fold; diff --git a/lib/Transforms/Utils/LoopUnrollPeel.cpp b/lib/Transforms/Utils/LoopUnrollPeel.cpp index 151a285af4e9..005306cf1898 100644 --- a/lib/Transforms/Utils/LoopUnrollPeel.cpp +++ b/lib/Transforms/Utils/LoopUnrollPeel.cpp @@ -1,9 +1,8 @@ //===- UnrollLoopPeel.cpp - Loop 
peeling utilities ------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -62,6 +61,10 @@ static cl::opt<unsigned> UnrollForcePeelCount( "unroll-force-peel-count", cl::init(0), cl::Hidden, cl::desc("Force a peel count regardless of profiling information.")); +static cl::opt<bool> UnrollPeelMultiDeoptExit( + "unroll-peel-multi-deopt-exit", cl::init(false), cl::Hidden, + cl::desc("Allow peeling of loops with multiple deopt exits.")); + // Designates that a Phi is estimated to become invariant after an "infinite" // number of loop iterations (i.e. only may become an invariant if the loop is // fully unrolled). @@ -74,6 +77,22 @@ bool llvm::canPeel(Loop *L) { if (!L->isLoopSimplifyForm()) return false; + if (UnrollPeelMultiDeoptExit) { + SmallVector<BasicBlock *, 4> Exits; + L->getUniqueNonLatchExitBlocks(Exits); + + if (!Exits.empty()) { + // Latch's terminator is a conditional branch, Latch is exiting and + // all non Latch exits ends up with deoptimize. + const BasicBlock *Latch = L->getLoopLatch(); + const BranchInst *T = dyn_cast<BranchInst>(Latch->getTerminator()); + return T && T->isConditional() && L->isLoopExiting(Latch) && + all_of(Exits, [](const BasicBlock *BB) { + return BB->getTerminatingDeoptimizeCall(); + }); + } + } + // Only peel loops that contain a single exit if (!L->getExitingBlock() || !L->getUniqueExitBlock()) return false; @@ -363,41 +382,89 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, static void updateBranchWeights(BasicBlock *Header, BranchInst *LatchBR, unsigned IterNumber, unsigned AvgIters, uint64_t &PeeledHeaderWeight) { + if (!PeeledHeaderWeight) + return; // FIXME: Pick a more realistic distribution. // Currently the proportion of weight we assign to the fall-through // side of the branch drops linearly with the iteration number, and we use // a 0.9 fudge factor to make the drop-off less sharp... - if (PeeledHeaderWeight) { - uint64_t FallThruWeight = - PeeledHeaderWeight * ((float)(AvgIters - IterNumber) / AvgIters * 0.9); - uint64_t ExitWeight = PeeledHeaderWeight - FallThruWeight; - PeeledHeaderWeight -= ExitWeight; - - unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); - MDBuilder MDB(LatchBR->getContext()); - MDNode *WeightNode = - HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThruWeight) - : MDB.createBranchWeights(FallThruWeight, ExitWeight); - LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); - } + uint64_t FallThruWeight = + PeeledHeaderWeight * ((float)(AvgIters - IterNumber) / AvgIters * 0.9); + uint64_t ExitWeight = PeeledHeaderWeight - FallThruWeight; + PeeledHeaderWeight -= ExitWeight; + + unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); + MDBuilder MDB(LatchBR->getContext()); + MDNode *WeightNode = + HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThruWeight) + : MDB.createBranchWeights(FallThruWeight, ExitWeight); + LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); +} + +/// Initialize the weights. +/// +/// \param Header The header block. +/// \param LatchBR The latch branch. 
+/// \param AvgIters The average number of iterations we expect the loop to have. +/// \param[out] ExitWeight The # of times the edge from Latch to Exit is taken. +/// \param[out] CurHeaderWeight The # of times the header is executed. +static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR, + unsigned AvgIters, uint64_t &ExitWeight, + uint64_t &CurHeaderWeight) { + uint64_t TrueWeight, FalseWeight; + if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) + return; + unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1; + ExitWeight = HeaderIdx ? TrueWeight : FalseWeight; + // The # of times the loop body executes is the sum of the exit block + // is taken and the # of times the backedges are taken. + CurHeaderWeight = TrueWeight + FalseWeight; +} + +/// Update the weights of original Latch block after peeling off all iterations. +/// +/// \param Header The header block. +/// \param LatchBR The latch branch. +/// \param ExitWeight The weight of the edge from Latch to Exit block. +/// \param CurHeaderWeight The # of time the header is executed. +static void fixupBranchWeights(BasicBlock *Header, BranchInst *LatchBR, + uint64_t ExitWeight, uint64_t CurHeaderWeight) { + // Adjust the branch weights on the loop exit. + if (!ExitWeight) + return; + + // The backedge count is the difference of current header weight and + // current loop exit weight. If the current header weight is smaller than + // the current loop exit weight, we mark the loop backedge weight as 1. + uint64_t BackEdgeWeight = 0; + if (ExitWeight < CurHeaderWeight) + BackEdgeWeight = CurHeaderWeight - ExitWeight; + else + BackEdgeWeight = 1; + MDBuilder MDB(LatchBR->getContext()); + unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1; + MDNode *WeightNode = + HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight) + : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); + LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); } /// Clones the body of the loop L, putting it between \p InsertTop and \p /// InsertBot. /// \param IterNumber The serial number of the iteration currently being /// peeled off. -/// \param Exit The exit block of the original loop. +/// \param ExitEdges The exit edges of the original loop. /// \param[out] NewBlocks A list of the blocks in the newly created clone /// \param[out] VMap The value map between the loop and the new clone. /// \param LoopBlocks A helper for DFS-traversal of the loop. /// \param LVMap A value-map that maps instructions from the original loop to /// instructions in the last peeled-off iteration. 
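The weight helpers factored out above (initBranchWeights, updateBranchWeights, fixupBranchWeights) only move integer profile metadata around, so the arithmetic can be followed with plain numbers. A self-contained sketch with invented weights and a simplified view of the bookkeeping:

#include <cstdint>
#include <cstdio>

int main() {
  // updateBranchWeights: the remaining header weight is split between the
  // fall-through and exit edges of each peeled copy; the fall-through share
  // drops linearly with the iteration number, scaled by the 0.9 fudge factor.
  uint64_t PeeledHeaderWeight = 1000; // invented starting profile weight
  const unsigned AvgIters = 4;        // invented average trip count
  for (unsigned Iter = 0; Iter < 3 && PeeledHeaderWeight; ++Iter) {
    uint64_t FallThru =
        PeeledHeaderWeight * ((float)(AvgIters - Iter) / AvgIters * 0.9);
    uint64_t Exit = PeeledHeaderWeight - FallThru;
    std::printf("peeled copy %u: fallthrough=%llu exit=%llu\n", Iter,
                (unsigned long long)FallThru, (unsigned long long)Exit);
    PeeledHeaderWeight -= Exit;
  }

  // fixupBranchWeights: the surviving loop's backedge weight is the current
  // header weight minus the exit weight, never allowed to reach zero.
  uint64_t CurHeaderWeight = 400, ExitWeight = 100; // invented numbers
  uint64_t BackEdgeWeight =
      ExitWeight < CurHeaderWeight ? CurHeaderWeight - ExitWeight : 1;
  std::printf("backedge weight after peeling: %llu\n",
              (unsigned long long)BackEdgeWeight);
  return 0;
}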
-static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop, - BasicBlock *InsertBot, BasicBlock *Exit, - SmallVectorImpl<BasicBlock *> &NewBlocks, - LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, - ValueToValueMapTy &LVMap, DominatorTree *DT, - LoopInfo *LI) { +static void cloneLoopBlocks( + Loop *L, unsigned IterNumber, BasicBlock *InsertTop, BasicBlock *InsertBot, + SmallVectorImpl<std::pair<BasicBlock *, BasicBlock *> > &ExitEdges, + SmallVectorImpl<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks, + ValueToValueMapTy &VMap, ValueToValueMapTy &LVMap, DominatorTree *DT, + LoopInfo *LI) { BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); BasicBlock *PreHeader = L->getLoopPreheader(); @@ -443,9 +510,11 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop, // iteration (for every other iteration) BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); BranchInst *LatchBR = cast<BranchInst>(NewLatch->getTerminator()); - unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); - LatchBR->setSuccessor(HeaderIdx, InsertBot); - LatchBR->setSuccessor(1 - HeaderIdx, Exit); + for (unsigned idx = 0, e = LatchBR->getNumSuccessors(); idx < e; ++idx) + if (LatchBR->getSuccessor(idx) == Header) { + LatchBR->setSuccessor(idx, InsertBot); + break; + } if (DT) DT->changeImmediateDominator(InsertBot, NewLatch); @@ -476,14 +545,14 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop, // we've just created. Note that this must happen *after* the incoming // values are adjusted, since the value going out of the latch may also be // a value coming into the header. - for (BasicBlock::iterator I = Exit->begin(); isa<PHINode>(I); ++I) { - PHINode *PHI = cast<PHINode>(I); - Value *LatchVal = PHI->getIncomingValueForBlock(Latch); - Instruction *LatchInst = dyn_cast<Instruction>(LatchVal); - if (LatchInst && L->contains(LatchInst)) - LatchVal = VMap[LatchVal]; - PHI->addIncoming(LatchVal, cast<BasicBlock>(VMap[Latch])); - } + for (auto Edge : ExitEdges) + for (PHINode &PHI : Edge.second->phis()) { + Value *LatchVal = PHI.getIncomingValueForBlock(Edge.first); + Instruction *LatchInst = dyn_cast<Instruction>(LatchVal); + if (LatchInst && L->contains(LatchInst)) + LatchVal = VMap[LatchVal]; + PHI.addIncoming(LatchVal, cast<BasicBlock>(VMap[Edge.first])); + } // LastValueMap is updated with the values for the current loop // which are used the next time this function is called. @@ -512,7 +581,20 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, BasicBlock *Header = L->getHeader(); BasicBlock *PreHeader = L->getLoopPreheader(); BasicBlock *Latch = L->getLoopLatch(); - BasicBlock *Exit = L->getUniqueExitBlock(); + SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> ExitEdges; + L->getExitEdges(ExitEdges); + + DenseMap<BasicBlock *, BasicBlock *> ExitIDom; + if (DT) { + assert(L->hasDedicatedExits() && "No dedicated exits?"); + for (auto Edge : ExitEdges) { + if (ExitIDom.count(Edge.second)) + continue; + BasicBlock *BB = DT->getNode(Edge.second)->getIDom()->getBlock(); + assert(L->contains(BB) && "IDom is not in a loop"); + ExitIDom[Edge.second] = BB; + } + } Function *F = Header->getParent(); @@ -577,16 +659,8 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, // newly created branches. BranchInst *LatchBR = cast<BranchInst>(cast<BasicBlock>(Latch)->getTerminator()); - unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 
0 : 1); - - uint64_t TrueWeight, FalseWeight; uint64_t ExitWeight = 0, CurHeaderWeight = 0; - if (LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) { - ExitWeight = HeaderIdx ? TrueWeight : FalseWeight; - // The # of times the loop body executes is the sum of the exit block - // weight and the # of times the backedges are taken. - CurHeaderWeight = TrueWeight + FalseWeight; - } + initBranchWeights(Header, LatchBR, PeelCount, ExitWeight, CurHeaderWeight); // For each peeled-off iteration, make a copy of the loop. for (unsigned Iter = 0; Iter < PeelCount; ++Iter) { @@ -602,8 +676,8 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, else CurHeaderWeight = 1; - cloneLoopBlocks(L, Iter, InsertTop, InsertBot, Exit, - NewBlocks, LoopBlocks, VMap, LVMap, DT, LI); + cloneLoopBlocks(L, Iter, InsertTop, InsertBot, ExitEdges, NewBlocks, + LoopBlocks, VMap, LVMap, DT, LI); // Remap to use values from the current iteration instead of the // previous one. @@ -614,7 +688,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, // latter is the first cloned loop body, as original PreHeader dominates // the original loop body. if (Iter == 0) - DT->changeImmediateDominator(Exit, cast<BasicBlock>(LVMap[Latch])); + for (auto Exit : ExitIDom) + DT->changeImmediateDominator(Exit.first, + cast<BasicBlock>(LVMap[Exit.second])); #ifdef EXPENSIVE_CHECKS assert(DT->verify(DominatorTree::VerificationLevel::Fast)); #endif @@ -645,36 +721,22 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, if (LatchInst && L->contains(LatchInst)) NewVal = LVMap[LatchInst]; - PHI->setIncomingValue(PHI->getBasicBlockIndex(NewPreHeader), NewVal); + PHI->setIncomingValueForBlock(NewPreHeader, NewVal); } - // Adjust the branch weights on the loop exit. - if (ExitWeight) { - // The backedge count is the difference of current header weight and - // current loop exit weight. If the current header weight is smaller than - // the current loop exit weight, we mark the loop backedge weight as 1. - uint64_t BackEdgeWeight = 0; - if (ExitWeight < CurHeaderWeight) - BackEdgeWeight = CurHeaderWeight - ExitWeight; - else - BackEdgeWeight = 1; - MDBuilder MDB(LatchBR->getContext()); - MDNode *WeightNode = - HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight) - : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); - LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); - } + fixupBranchWeights(Header, LatchBR, ExitWeight, CurHeaderWeight); - // If the loop is nested, we changed the parent loop, update SE. - if (Loop *ParentLoop = L->getParentLoop()) { - SE->forgetLoop(ParentLoop); + if (Loop *ParentLoop = L->getParentLoop()) + L = ParentLoop; - // FIXME: Incrementally update loop-simplify - simplifyLoop(ParentLoop, DT, LI, SE, AC, PreserveLCSSA); - } else { - // FIXME: Incrementally update loop-simplify - simplifyLoop(L, DT, LI, SE, AC, PreserveLCSSA); - } + // We modified the loop, update SE. + SE->forgetTopmostLoop(L); + + // Finally DomtTree must be correct. 
+ assert(DT->verify(DominatorTree::VerificationLevel::Fast)); + + // FIXME: Incrementally update loop-simplify + simplifyLoop(L, DT, LI, SE, AC, nullptr, PreserveLCSSA); NumPeeled++; diff --git a/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 00d2fd2fdbac..d22fdb4d52dc 100644 --- a/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -1,9 +1,8 @@ //===-- UnrollLoopRuntime.cpp - Runtime Loop unrolling utilities ----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -125,11 +124,10 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, // Update the existing PHI node operand with the value from the // new PHI node. How this is done depends on if the existing // PHI node is in the original loop block, or the exit block. - if (L->contains(&PN)) { - PN.setIncomingValue(PN.getBasicBlockIndex(NewPreHeader), NewPN); - } else { + if (L->contains(&PN)) + PN.setIncomingValueForBlock(NewPreHeader, NewPN); + else PN.addIncoming(NewPN, PrologExit); - } } } @@ -265,7 +263,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, // Update the existing PHI node operand with the value from the new PHI // node. Corresponding instruction in epilog loop should be PHI. PHINode *VPN = cast<PHINode>(VMap[&PN]); - VPN->setIncomingValue(VPN->getBasicBlockIndex(EpilogPreHeader), NewPN); + VPN->setIncomingValueForBlock(EpilogPreHeader, NewPN); } } @@ -426,10 +424,9 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, /// Returns true if we can safely unroll a multi-exit/exiting loop. OtherExits /// is populated with all the loop exit blocks other than the LatchExit block. -static bool -canSafelyUnrollMultiExitLoop(Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits, - BasicBlock *LatchExit, bool PreserveLCSSA, - bool UseEpilogRemainder) { +static bool canSafelyUnrollMultiExitLoop(Loop *L, BasicBlock *LatchExit, + bool PreserveLCSSA, + bool UseEpilogRemainder) { // We currently have some correctness constrains in unrolling a multi-exit // loop. Check for these below. @@ -437,11 +434,6 @@ canSafelyUnrollMultiExitLoop(Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits, // We rely on LCSSA form being preserved when the exit blocks are transformed. if (!PreserveLCSSA) return false; - SmallVector<BasicBlock *, 4> Exits; - L->getUniqueExitBlocks(Exits); - for (auto *BB : Exits) - if (BB != LatchExit) - OtherExits.push_back(BB); // TODO: Support multiple exiting blocks jumping to the `LatchExit` when // UnrollRuntimeMultiExit is true. 
This will need updating the logic in @@ -471,9 +463,8 @@ static bool canProfitablyUnrollMultiExitLoop( bool PreserveLCSSA, bool UseEpilogRemainder) { #if !defined(NDEBUG) - SmallVector<BasicBlock *, 8> OtherExitsDummyCheck; - assert(canSafelyUnrollMultiExitLoop(L, OtherExitsDummyCheck, LatchExit, - PreserveLCSSA, UseEpilogRemainder) && + assert(canSafelyUnrollMultiExitLoop(L, LatchExit, PreserveLCSSA, + UseEpilogRemainder) && "Should be safe to unroll before checking profitability!"); #endif @@ -554,10 +545,10 @@ static bool canProfitablyUnrollMultiExitLoop( bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, bool AllowExpensiveTripCount, bool UseEpilogRemainder, - bool UnrollRemainder, LoopInfo *LI, - ScalarEvolution *SE, DominatorTree *DT, - AssumptionCache *AC, bool PreserveLCSSA, - Loop **ResultLoop) { + bool UnrollRemainder, bool ForgetAllSCEV, + LoopInfo *LI, ScalarEvolution *SE, + DominatorTree *DT, AssumptionCache *AC, + bool PreserveLCSSA, Loop **ResultLoop) { LLVM_DEBUG(dbgs() << "Trying runtime unrolling on Loop: \n"); LLVM_DEBUG(L->dump()); LLVM_DEBUG(UseEpilogRemainder ? dbgs() << "Using epilog remainder.\n" @@ -597,8 +588,9 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, // These are exit blocks other than the target of the latch exiting block. SmallVector<BasicBlock *, 4> OtherExits; + L->getUniqueNonLatchExitBlocks(OtherExits); bool isMultiExitUnrollingEnabled = - canSafelyUnrollMultiExitLoop(L, OtherExits, LatchExit, PreserveLCSSA, + canSafelyUnrollMultiExitLoop(L, LatchExit, PreserveLCSSA, UseEpilogRemainder) && canProfitablyUnrollMultiExitLoop(L, OtherExits, LatchExit, PreserveLCSSA, UseEpilogRemainder); @@ -939,23 +931,24 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, if (OtherExits.size() > 0) { // Generate dedicated exit blocks for the original loop, to preserve // LoopSimplifyForm. - formDedicatedExitBlocks(L, DT, LI, PreserveLCSSA); + formDedicatedExitBlocks(L, DT, LI, nullptr, PreserveLCSSA); // Generate dedicated exit blocks for the remainder loop if one exists, to // preserve LoopSimplifyForm. 
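UnrollRuntimeLoopRemainder exists because an unroll factor rarely divides the trip count exactly: the leftover iterations run in a separate prolog or epilog loop. The underlying arithmetic, with invented numbers (the in-tree code works on the backedge-taken count and IR values rather than plain integers):

#include <cstdio>

int main() {
  const unsigned TripCount = 10, Count = 4;              // invented example
  unsigned MainIterations = (TripCount / Count) * Count; // 8, run unrolled by 4
  unsigned RemainderIterations = TripCount % Count;      // 2, run in the remainder loop
  std::printf("unrolled body executes %u iterations, remainder loop %u\n",
              MainIterations, RemainderIterations);
  return 0;
}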
if (remainderLoop) - formDedicatedExitBlocks(remainderLoop, DT, LI, PreserveLCSSA); + formDedicatedExitBlocks(remainderLoop, DT, LI, nullptr, PreserveLCSSA); } auto UnrollResult = LoopUnrollResult::Unmodified; if (remainderLoop && UnrollRemainder) { LLVM_DEBUG(dbgs() << "Unrolling remainder loop\n"); UnrollResult = - UnrollLoop(remainderLoop, /*Count*/ Count - 1, /*TripCount*/ Count - 1, - /*Force*/ false, /*AllowRuntime*/ false, - /*AllowExpensiveTripCount*/ false, /*PreserveCondBr*/ true, - /*PreserveOnlyFirst*/ false, /*TripMultiple*/ 1, - /*PeelCount*/ 0, /*UnrollRemainder*/ false, LI, SE, DT, AC, - /*ORE*/ nullptr, PreserveLCSSA); + UnrollLoop(remainderLoop, + {/*Count*/ Count - 1, /*TripCount*/ Count - 1, + /*Force*/ false, /*AllowRuntime*/ false, + /*AllowExpensiveTripCount*/ false, /*PreserveCondBr*/ true, + /*PreserveOnlyFirst*/ false, /*TripMultiple*/ 1, + /*PeelCount*/ 0, /*UnrollRemainder*/ false, ForgetAllSCEV}, + LI, SE, DT, AC, /*ORE*/ nullptr, PreserveLCSSA); } if (ResultLoop && UnrollResult != LoopUnrollResult::FullyUnrolled) diff --git a/lib/Transforms/Utils/LoopUtils.cpp b/lib/Transforms/Utils/LoopUtils.cpp index a93d1aeb62ef..ec226e65f650 100644 --- a/lib/Transforms/Utils/LoopUtils.cpp +++ b/lib/Transforms/Utils/LoopUtils.cpp @@ -1,9 +1,8 @@ //===-- LoopUtils.cpp - Loop Utility functions -------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -15,10 +14,12 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" @@ -27,7 +28,6 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DIBuilder.h" -#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -47,6 +47,7 @@ using namespace llvm::PatternMatch; static const char *LLVMLoopDisableNonforced = "llvm.loop.disable_nonforced"; bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI, + MemorySSAUpdater *MSSAU, bool PreserveLCSSA) { bool Changed = false; @@ -66,6 +67,9 @@ bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI, if (isa<IndirectBrInst>(PredBB->getTerminator())) // We cannot rewrite exiting edges from an indirectbr. return false; + if (isa<CallBrInst>(PredBB->getTerminator())) + // We cannot rewrite exiting edges from a callbr. 
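The UnrollLoop(remainderLoop, {...}) call in the hunk above also shows the interface change threaded through this commit: the unroller now takes a single options aggregate (the ULO.* accesses seen earlier) instead of a long list of positional flags. A sketch of the same refactoring pattern with invented names, not the in-tree declarations:

struct UnrollOptionsSketch {
  unsigned Count;
  unsigned TripCount;
  bool Force;
  bool AllowRuntime;
  bool AllowExpensiveTripCount;
  bool UnrollRemainder;
  bool ForgetAllSCEV;
};

static void unrollSketch(const UnrollOptionsSketch &ULO) {
  // The body reads ULO.Count, ULO.TripCount, ... instead of loose arguments,
  // so adding a new knob (e.g. ForgetAllSCEV) does not touch every call site.
  (void)ULO;
}

int main() {
  unrollSketch({/*Count*/ 3, /*TripCount*/ 3, /*Force*/ false,
                /*AllowRuntime*/ false, /*AllowExpensiveTripCount*/ false,
                /*UnrollRemainder*/ false, /*ForgetAllSCEV*/ true});
  return 0;
}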
+ return false; InLoopPredecessors.push_back(PredBB); } else { @@ -79,7 +83,7 @@ bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI, return false; auto *NewExitBB = SplitBlockPredecessors( - BB, InLoopPredecessors, ".loopexit", DT, LI, nullptr, PreserveLCSSA); + BB, InLoopPredecessors, ".loopexit", DT, LI, MSSAU, PreserveLCSSA); if (!NewExitBB) LLVM_DEBUG( @@ -217,7 +221,10 @@ static Optional<bool> getOptionalBoolLoopAttribute(const Loop *TheLoop, // When the value is absent it is interpreted as 'attribute set'. return true; case 2: - return mdconst::extract_or_null<ConstantInt>(MD->getOperand(1).get()); + if (ConstantInt *IntMD = + mdconst::extract_or_null<ConstantInt>(MD->getOperand(1).get())) + return IntMD->getZExtValue(); + return true; } llvm_unreachable("unexpected number of options"); } @@ -376,17 +383,17 @@ TransformationMode llvm::hasVectorizeTransformation(Loop *L) { Optional<int> InterleaveCount = getOptionalIntLoopAttribute(L, "llvm.loop.interleave.count"); - if (Enable == true) { - // 'Forcing' vector width and interleave count to one effectively disables - // this tranformation. - if (VectorizeWidth == 1 && InterleaveCount == 1) - return TM_SuppressedByUser; - return TM_ForcedByUser; - } + // 'Forcing' vector width and interleave count to one effectively disables + // this tranformation. + if (Enable == true && VectorizeWidth == 1 && InterleaveCount == 1) + return TM_SuppressedByUser; if (getBooleanLoopAttribute(L, "llvm.loop.isvectorized")) return TM_Disable; + if (Enable == true) + return TM_ForcedByUser; + if (VectorizeWidth == 1 && InterleaveCount == 1) return TM_Disable; @@ -528,10 +535,9 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr, DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); if (DT) { // Update the dominator tree by informing it about the new edge from the - // preheader to the exit. - DTU.insertEdge(Preheader, ExitBlock); - // Inform the dominator tree about the removed edge. - DTU.deleteEdge(Preheader, L->getHeader()); + // preheader to the exit and the removed edge. + DTU.applyUpdates({{DominatorTree::Insert, Preheader, ExitBlock}, + {DominatorTree::Delete, Preheader, L->getHeader()}}); } // Use a map to unique and a vector to guarantee deterministic ordering. @@ -578,10 +584,14 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr, // dbg.value truncates the range of any dbg.value before the loop where the // loop used to be. This is particularly important for constant values. DIBuilder DIB(*ExitBlock->getModule()); + Instruction *InsertDbgValueBefore = ExitBlock->getFirstNonPHI(); + assert(InsertDbgValueBefore && + "There should be a non-PHI instruction in exit block, else these " + "instructions will have no parent."); for (auto *DVI : DeadDebugInst) - DIB.insertDbgValueIntrinsic( - UndefValue::get(Builder.getInt32Ty()), DVI->getVariable(), - DVI->getExpression(), DVI->getDebugLoc(), ExitBlock->getFirstNonPHI()); + DIB.insertDbgValueIntrinsic(UndefValue::get(Builder.getInt32Ty()), + DVI->getVariable(), DVI->getExpression(), + DVI->getDebugLoc(), InsertDbgValueBefore); // Remove the block from the reference counting scheme, so that we can // delete it freely later. @@ -611,20 +621,28 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr, } Optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L) { - // Only support loops with a unique exiting block, and a latch. 
- if (!L->getExitingBlock()) - return None; + // Support loops with an exiting latch and other existing exists only + // deoptimize. // Get the branch weights for the loop's backedge. - BranchInst *LatchBR = - dyn_cast<BranchInst>(L->getLoopLatch()->getTerminator()); - if (!LatchBR || LatchBR->getNumSuccessors() != 2) + BasicBlock *Latch = L->getLoopLatch(); + if (!Latch) + return None; + BranchInst *LatchBR = dyn_cast<BranchInst>(Latch->getTerminator()); + if (!LatchBR || LatchBR->getNumSuccessors() != 2 || !L->isLoopExiting(Latch)) return None; assert((LatchBR->getSuccessor(0) == L->getHeader() || LatchBR->getSuccessor(1) == L->getHeader()) && "At least one edge out of the latch must go to the header"); + SmallVector<BasicBlock *, 4> ExitBlocks; + L->getUniqueNonLatchExitBlocks(ExitBlocks); + if (any_of(ExitBlocks, [](const BasicBlock *EB) { + return !EB->getTerminatingDeoptimizeCall(); + })) + return None; + // To estimate the number of times the loop body was executed, we want to // know the number of times the backedge was taken, vs. the number of times // we exited the loop. @@ -665,16 +683,6 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop, return true; } -/// Adds a 'fast' flag to floating point operations. -static Value *addFastMathFlag(Value *V) { - if (isa<FPMathOperator>(V)) { - FastMathFlags Flags; - Flags.setFast(); - cast<Instruction>(V)->setFastMathFlags(Flags); - } - return V; -} - Value *llvm::createMinMaxOp(IRBuilder<> &Builder, RecurrenceDescriptor::MinMaxRecurrenceKind RK, Value *Left, Value *Right) { @@ -778,9 +786,9 @@ llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op, ConstantVector::get(ShuffleMask), "rdx.shuf"); if (Op != Instruction::ICmp && Op != Instruction::FCmp) { - // Floating point operations had to be 'fast' to enable the reduction. - TmpVec = addFastMathFlag(Builder.CreateBinOp((Instruction::BinaryOps)Op, - TmpVec, Shuf, "bin.rdx")); + // The builder propagates its fast-math-flags setting. + TmpVec = Builder.CreateBinOp((Instruction::BinaryOps)Op, TmpVec, Shuf, + "bin.rdx"); } else { assert(MinMaxKind != RecurrenceDescriptor::MRK_Invalid && "Invalid min/max"); @@ -801,13 +809,9 @@ Value *llvm::createSimpleTargetReduction( ArrayRef<Value *> RedOps) { assert(isa<VectorType>(Src->getType()) && "Type must be a vector"); - Value *ScalarUdf = UndefValue::get(Src->getType()->getVectorElementType()); std::function<Value *()> BuildFunc; using RD = RecurrenceDescriptor; RD::MinMaxRecurrenceKind MinMaxKind = RD::MRK_Invalid; - // TODO: Support creating ordered reductions. 
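getLoopEstimatedTripCount, relaxed above to accept an exiting latch plus deoptimizing side exits, still derives its estimate purely from the latch branch weights: roughly the ratio of the backedge weight to the exit weight. A sketch of that arithmetic with plain integers (the metadata handling and exact rounding of the real implementation are elided):

#include <cstdint>
#include <cstdio>

static uint64_t estimatedTripCount(uint64_t BackedgeWeight, uint64_t ExitWeight) {
  if (ExitWeight == 0)
    return 0; // no usable profile, no estimate
  // Ratio of backedge weight to exit weight, rounded to nearest.
  return (BackedgeWeight + ExitWeight / 2) / ExitWeight;
}

int main() {
  // Profile says the backedge was taken 300 times for 100 loop exits.
  std::printf("estimated trip count: %llu\n",
              (unsigned long long)estimatedTripCount(300, 100));
  return 0;
}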
- FastMathFlags FMFFast; - FMFFast.setFast(); switch (Opcode) { case Instruction::Add: @@ -827,15 +831,15 @@ Value *llvm::createSimpleTargetReduction( break; case Instruction::FAdd: BuildFunc = [&]() { - auto Rdx = Builder.CreateFAddReduce(ScalarUdf, Src); - cast<CallInst>(Rdx)->setFastMathFlags(FMFFast); + auto Rdx = Builder.CreateFAddReduce( + Constant::getNullValue(Src->getType()->getVectorElementType()), Src); return Rdx; }; break; case Instruction::FMul: BuildFunc = [&]() { - auto Rdx = Builder.CreateFMulReduce(ScalarUdf, Src); - cast<CallInst>(Rdx)->setFastMathFlags(FMFFast); + Type *Ty = Src->getType()->getVectorElementType(); + auto Rdx = Builder.CreateFMulReduce(ConstantFP::get(Ty, 1.0), Src); return Rdx; }; break; @@ -880,6 +884,12 @@ Value *llvm::createTargetReduction(IRBuilder<> &B, RD::RecurrenceKind RecKind = Desc.getRecurrenceKind(); TargetTransformInfo::ReductionFlags Flags; Flags.NoNaN = NoNaN; + + // All ops in the reduction inherit fast-math-flags from the recurrence + // descriptor. + IRBuilder<>::FastMathFlagGuard FMFGuard(B); + B.setFastMathFlags(Desc.getFastMathFlags()); + switch (RecKind) { case RD::RK_FloatAdd: return createSimpleTargetReduction(B, TTI, Instruction::FAdd, Src, Flags); diff --git a/lib/Transforms/Utils/LoopVersioning.cpp b/lib/Transforms/Utils/LoopVersioning.cpp index abbcd5f9e3b8..a9a480a4b7f9 100644 --- a/lib/Transforms/Utils/LoopVersioning.cpp +++ b/lib/Transforms/Utils/LoopVersioning.cpp @@ -1,9 +1,8 @@ //===- LoopVersioning.cpp - Utility to version a loop ---------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -281,8 +280,9 @@ public: bool Changed = false; for (Loop *L : Worklist) { const LoopAccessInfo &LAI = LAA->getInfo(L); - if (L->isLoopSimplifyForm() && (LAI.getNumRuntimePointerChecks() || - !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) { + if (L->isLoopSimplifyForm() && !LAI.hasConvergentOp() && + (LAI.getNumRuntimePointerChecks() || + !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) { LoopVersioning LVer(LAI, L, LI, DT, SE); LVer.versionLoop(); LVer.annotateLoopWithNoAlias(); diff --git a/lib/Transforms/Utils/LowerInvoke.cpp b/lib/Transforms/Utils/LowerInvoke.cpp index c852d538b0d1..fe67e191dc62 100644 --- a/lib/Transforms/Utils/LowerInvoke.cpp +++ b/lib/Transforms/Utils/LowerInvoke.cpp @@ -1,9 +1,8 @@ //===- LowerInvoke.cpp - Eliminate Invoke instructions --------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -53,7 +52,8 @@ static bool runImpl(Function &F) { II->getOperandBundlesAsDefs(OpBundles); // Insert a normal call instruction... 
CallInst *NewCall = - CallInst::Create(II->getCalledValue(), CallArgs, OpBundles, "", II); + CallInst::Create(II->getFunctionType(), II->getCalledValue(), + CallArgs, OpBundles, "", II); NewCall->takeName(II); NewCall->setCallingConv(II->getCallingConv()); NewCall->setAttributes(II->getAttributes()); diff --git a/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/lib/Transforms/Utils/LowerMemIntrinsics.cpp index 661b4fa5bcb7..0cc085dc366c 100644 --- a/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -1,9 +1,8 @@ //===- LowerMemIntrinsics.cpp ----------------------------------*- C++ -*--===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -73,7 +72,7 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, // Loop Body Value *SrcGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, LoopIndex); - Value *Load = LoopBuilder.CreateLoad(SrcGEP, SrcIsVolatile); + Value *Load = LoopBuilder.CreateLoad(LoopOpType, SrcGEP, SrcIsVolatile); Value *DstGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, LoopIndex); LoopBuilder.CreateStore(Load, DstGEP, DstIsVolatile); @@ -115,7 +114,7 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, : RBuilder.CreateBitCast(SrcAddr, SrcPtrType); Value *SrcGEP = RBuilder.CreateInBoundsGEP( OpTy, CastedSrc, ConstantInt::get(TypeOfCopyLen, GepIndex)); - Value *Load = RBuilder.CreateLoad(SrcGEP, SrcIsVolatile); + Value *Load = RBuilder.CreateLoad(OpTy, SrcGEP, SrcIsVolatile); // Cast destination to operand type and store. PointerType *DstPtrType = PointerType::get(OpTy, DstAS); @@ -182,7 +181,7 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, LoopIndex->addIncoming(ConstantInt::get(CopyLenType, 0U), PreLoopBB); Value *SrcGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, LoopIndex); - Value *Load = LoopBuilder.CreateLoad(SrcGEP, SrcIsVolatile); + Value *Load = LoopBuilder.CreateLoad(LoopOpType, SrcGEP, SrcIsVolatile); Value *DstGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, LoopIndex); LoopBuilder.CreateStore(Load, DstGEP, DstIsVolatile); @@ -235,7 +234,7 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, Value *FullOffset = ResBuilder.CreateAdd(RuntimeBytesCopied, ResidualIndex); Value *SrcGEP = ResBuilder.CreateInBoundsGEP(Int8Type, SrcAsInt8, FullOffset); - Value *Load = ResBuilder.CreateLoad(SrcGEP, SrcIsVolatile); + Value *Load = ResBuilder.CreateLoad(Int8Type, SrcGEP, SrcIsVolatile); Value *DstGEP = ResBuilder.CreateInBoundsGEP(Int8Type, DstAsInt8, FullOffset); ResBuilder.CreateStore(Load, DstGEP, DstIsVolatile); @@ -293,6 +292,8 @@ static void createMemMoveLoop(Instruction *InsertBefore, BasicBlock *OrigBB = InsertBefore->getParent(); Function *F = OrigBB->getParent(); + Type *EltTy = cast<PointerType>(SrcAddr->getType())->getElementType(); + // Create the a comparison of src and dst, based on which we jump to either // the forward-copy part of the function (if src >= dst) or the backwards-copy // part (if src < dst). 
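createMemMoveLoop lowers a memmove into the two loops described by the comment above: a backward copy when the source lies below the destination, a forward copy otherwise, so overlapping bytes are never clobbered before they are read. The same behaviour in plain C++, for reference:

#include <cstddef>
#include <cstdio>

static void memmoveSketch(char *Dst, const char *Src, std::size_t Len) {
  if (Src < Dst) {
    for (std::size_t i = Len; i != 0; --i)   // backward copy: index_ptr counts down
      Dst[i - 1] = Src[i - 1];
  } else {
    for (std::size_t i = 0; i != Len; ++i)   // forward copy: index counts up
      Dst[i] = Src[i];
  }
}

int main() {
  char Buf[] = "abcdef";
  memmoveSketch(Buf + 2, Buf, 4);  // overlapping regions
  std::printf("%s\n", Buf);        // prints "ababcd": the overlap is handled
  return 0;
}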
@@ -331,9 +332,10 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *IndexPtr = LoopBuilder.CreateSub( LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr"); Value *Element = LoopBuilder.CreateLoad( - LoopBuilder.CreateInBoundsGEP(SrcAddr, IndexPtr), "element"); - LoopBuilder.CreateStore(Element, - LoopBuilder.CreateInBoundsGEP(DstAddr, IndexPtr)); + EltTy, LoopBuilder.CreateInBoundsGEP(EltTy, SrcAddr, IndexPtr), + "element"); + LoopBuilder.CreateStore( + Element, LoopBuilder.CreateInBoundsGEP(EltTy, DstAddr, IndexPtr)); LoopBuilder.CreateCondBr( LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)), ExitBB, LoopBB); @@ -348,9 +350,10 @@ static void createMemMoveLoop(Instruction *InsertBefore, IRBuilder<> FwdLoopBuilder(FwdLoopBB); PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr"); Value *FwdElement = FwdLoopBuilder.CreateLoad( - FwdLoopBuilder.CreateInBoundsGEP(SrcAddr, FwdCopyPhi), "element"); + EltTy, FwdLoopBuilder.CreateInBoundsGEP(EltTy, SrcAddr, FwdCopyPhi), + "element"); FwdLoopBuilder.CreateStore( - FwdElement, FwdLoopBuilder.CreateInBoundsGEP(DstAddr, FwdCopyPhi)); + FwdElement, FwdLoopBuilder.CreateInBoundsGEP(EltTy, DstAddr, FwdCopyPhi)); Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd( FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment"); FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen), diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp index d019a44fc705..8256e3b5f5af 100644 --- a/lib/Transforms/Utils/LowerSwitch.cpp +++ b/lib/Transforms/Utils/LowerSwitch.cpp @@ -1,9 +1,8 @@ //===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -17,8 +16,12 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LazyValueInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -28,6 +31,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -58,9 +62,8 @@ static bool IsInRanges(const IntRange &R, // Find the first range whose High field is >= R.High, // then check if the Low field is <= R.Low. If so, we // have a Range that covers R. 
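The IsInRanges hunk that follows swaps the iterator-based std::lower_bound call for the range-based llvm::lower_bound helper; the containment test itself is unchanged. The same check written against standard containers, assuming the ranges are sorted by their High field:

#include <algorithm>
#include <cstdio>
#include <vector>

struct IntRange { long Low, High; };

static bool isInRanges(const IntRange &R, const std::vector<IntRange> &Ranges) {
  // First candidate whose High >= R.High; it covers R iff its Low <= R.Low.
  auto I = std::lower_bound(
      Ranges.begin(), Ranges.end(), R,
      [](const IntRange &A, const IntRange &B) { return A.High < B.High; });
  return I != Ranges.end() && I->Low <= R.Low;
}

int main() {
  std::vector<IntRange> Sorted = {{0, 3}, {5, 9}};    // sorted by High
  std::printf("%d %d\n", isInRanges({6, 8}, Sorted),  // 1: covered by [5,9]
              isInRanges({3, 5}, Sorted));            // 0: straddles a gap
  return 0;
}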
- auto I = std::lower_bound( - Ranges.begin(), Ranges.end(), R, - [](const IntRange &A, const IntRange &B) { return A.High < B.High; }); + auto I = llvm::lower_bound( + Ranges, R, [](IntRange A, IntRange B) { return A.High < B.High; }); return I != Ranges.end() && I->Low <= R.Low; } @@ -78,6 +81,10 @@ namespace { bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<LazyValueInfoWrapperPass>(); + } + struct CaseRange { ConstantInt* Low; ConstantInt* High; @@ -91,15 +98,18 @@ namespace { using CaseItr = std::vector<CaseRange>::iterator; private: - void processSwitchInst(SwitchInst *SI, SmallPtrSetImpl<BasicBlock*> &DeleteList); + void processSwitchInst(SwitchInst *SI, + SmallPtrSetImpl<BasicBlock *> &DeleteList, + AssumptionCache *AC, LazyValueInfo *LVI); BasicBlock *switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, ConstantInt *UpperBound, Value *Val, BasicBlock *Predecessor, BasicBlock *OrigBlock, BasicBlock *Default, const std::vector<IntRange> &UnreachableRanges); - BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, - BasicBlock *Default); + BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, + ConstantInt *LowerBound, ConstantInt *UpperBound, + BasicBlock *OrigBlock, BasicBlock *Default); unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); }; @@ -121,8 +131,12 @@ char LowerSwitch::ID = 0; // Publicly exposed interface to pass... char &llvm::LowerSwitchID = LowerSwitch::ID; -INITIALIZE_PASS(LowerSwitch, "lowerswitch", - "Lower SwitchInst's to branches", false, false) +INITIALIZE_PASS_BEGIN(LowerSwitch, "lowerswitch", + "Lower SwitchInst's to branches", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) +INITIALIZE_PASS_END(LowerSwitch, "lowerswitch", + "Lower SwitchInst's to branches", false, false) // createLowerSwitchPass - Interface to this file... FunctionPass *llvm::createLowerSwitchPass() { @@ -130,6 +144,17 @@ FunctionPass *llvm::createLowerSwitchPass() { } bool LowerSwitch::runOnFunction(Function &F) { + LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); + auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>(); + AssumptionCache *AC = ACT ? &ACT->getAssumptionCache(F) : nullptr; + // Prevent LazyValueInfo from using the DominatorTree as LowerSwitch does not + // preserve it and it becomes stale (when available) pretty much immediately. + // Currently the DominatorTree is only used by LowerSwitch indirectly via LVI + // and computeKnownBits to refine isValidAssumeForContext's results. Given + // that the latter can handle some of the simple cases w/o a DominatorTree, + // it's easier to refrain from using the tree than to keep it up to date. 
+ LVI->disableDT(); + bool Changed = false; SmallPtrSet<BasicBlock*, 8> DeleteList; @@ -143,11 +168,12 @@ bool LowerSwitch::runOnFunction(Function &F) { if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { Changed = true; - processSwitchInst(SI, DeleteList); + processSwitchInst(SI, DeleteList, AC, LVI); } } for (BasicBlock* BB: DeleteList) { + LVI->eraseBlock(BB); DeleteDeadBlock(BB); } @@ -160,10 +186,11 @@ static raw_ostream &operator<<(raw_ostream &O, const LowerSwitch::CaseVector &C) { O << "["; - for (LowerSwitch::CaseVector::const_iterator B = C.begin(), - E = C.end(); B != E; ) { - O << *B->Low << " -" << *B->High; - if (++B != E) O << ", "; + for (LowerSwitch::CaseVector::const_iterator B = C.begin(), E = C.end(); + B != E;) { + O << "[" << B->Low->getValue() << ", " << B->High->getValue() << "]"; + if (++B != E) + O << ", "; } return O << "]"; @@ -179,8 +206,9 @@ static raw_ostream &operator<<(raw_ostream &O, /// 2) Removed if subsequent incoming values now share the same case, i.e., /// multiple outcome edges are condensed into one. This is necessary to keep the /// number of phi values equal to the number of branches to SuccBB. -static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, - unsigned NumMergedCases) { +static void +fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, + const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) { for (BasicBlock::iterator I = SuccBB->begin(), IE = SuccBB->getFirstNonPHI()->getIterator(); I != IE; ++I) { @@ -222,6 +250,7 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, BasicBlock *Predecessor, BasicBlock *OrigBlock, BasicBlock *Default, const std::vector<IntRange> &UnreachableRanges) { + assert(LowerBound && UpperBound && "Bounds must be initialized"); unsigned Size = End - Begin; if (Size == 1) { @@ -231,13 +260,12 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, // because the bounds already tell us so. if (Begin->Low == LowerBound && Begin->High == UpperBound) { unsigned NumMergedCases = 0; - if (LowerBound && UpperBound) - NumMergedCases = - UpperBound->getSExtValue() - LowerBound->getSExtValue(); + NumMergedCases = UpperBound->getSExtValue() - LowerBound->getSExtValue(); fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); return Begin->BB; } - return newLeafBlock(*Begin, Val, OrigBlock, Default); + return newLeafBlock(*Begin, Val, LowerBound, UpperBound, OrigBlock, + Default); } unsigned Mid = Size / 2; @@ -247,8 +275,8 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, LLVM_DEBUG(dbgs() << "RHS: " << RHS << "\n"); CaseRange &Pivot = *(Begin + Mid); - LLVM_DEBUG(dbgs() << "Pivot ==> " << Pivot.Low->getValue() << " -" - << Pivot.High->getValue() << "\n"); + LLVM_DEBUG(dbgs() << "Pivot ==> [" << Pivot.Low->getValue() << ", " + << Pivot.High->getValue() << "]\n"); // NewLowerBound here should never be the integer minimal value. 
// This is because it is computed from a case range that is never @@ -270,14 +298,10 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, NewUpperBound = LHS.back().High; } - LLVM_DEBUG(dbgs() << "LHS Bounds ==> "; if (LowerBound) { - dbgs() << LowerBound->getSExtValue(); - } else { dbgs() << "NONE"; } dbgs() << " - " - << NewUpperBound->getSExtValue() << "\n"; - dbgs() << "RHS Bounds ==> "; - dbgs() << NewLowerBound->getSExtValue() << " - "; if (UpperBound) { - dbgs() << UpperBound->getSExtValue() << "\n"; - } else { dbgs() << "NONE\n"; }); + LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getSExtValue() << ", " + << NewUpperBound->getSExtValue() << "]\n" + << "RHS Bounds ==> [" << NewLowerBound->getSExtValue() + << ", " << UpperBound->getSExtValue() << "]\n"); // Create a new node that checks if the value is < pivot. Go to the // left branch if it is and right branch if not. @@ -305,9 +329,11 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, /// switch's value == the case's value. If not, then it jumps to the default /// branch. At this point in the tree, the value can't be another valid case /// value, so the jump to the "default" branch is warranted. -BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, - BasicBlock* OrigBlock, - BasicBlock* Default) { +BasicBlock *LowerSwitch::newLeafBlock(CaseRange &Leaf, Value *Val, + ConstantInt *LowerBound, + ConstantInt *UpperBound, + BasicBlock *OrigBlock, + BasicBlock *Default) { Function* F = OrigBlock->getParent(); BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf); @@ -320,10 +346,14 @@ BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, Leaf.Low, "SwitchLeaf"); } else { // Make range comparison - if (Leaf.Low->isMinValue(true /*isSigned*/)) { + if (Leaf.Low == LowerBound) { // Val >= Min && Val <= Hi --> Val <= Hi Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, "SwitchLeaf"); + } else if (Leaf.High == UpperBound) { + // Val <= Max && Val >= Lo --> Val >= Lo + Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low, + "SwitchLeaf"); } else if (Leaf.Low->isZero()) { // Val >= 0 && Val <= Hi --> Val <=u Hi Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, @@ -363,14 +393,20 @@ BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, return NewLeaf; } -/// Transform simple list of Cases into list of CaseRange's. +/// Transform simple list of \p SI's cases into list of CaseRange's \p Cases. +/// \post \p Cases wouldn't contain references to \p SI's default BB. +/// \returns Number of \p SI's cases that do not reference \p SI's default BB. unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { - unsigned numCmps = 0; + unsigned NumSimpleCases = 0; // Start with "simple" cases - for (auto Case : SI->cases()) + for (auto Case : SI->cases()) { + if (Case.getCaseSuccessor() == SI->getDefaultDest()) + continue; Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(), Case.getCaseSuccessor())); + ++NumSimpleCases; + } llvm::sort(Cases, CaseCmp()); @@ -396,60 +432,88 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { Cases.erase(std::next(I), Cases.end()); } - for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { - if (I->Low != I->High) - // A range counts double, since it requires two compares. 
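The switchConvert and newLeafBlock hunks above build a balanced tree of compares out of the sorted, clusterified cases: split at a pivot, recurse into each half, and emit a single range check per leaf (collapsed to one compare when the leaf touches the known lower or upper bound). A stand-alone sketch of the dispatch structure that tree encodes, with invented types in place of basic blocks:

#include <cstddef>
#include <cstdio>
#include <vector>

struct CaseRange { long Low, High; int Dest; };

static int dispatch(const std::vector<CaseRange> &Cases, std::size_t Begin,
                    std::size_t End, long Val, int DefaultDest) {
  if (End - Begin == 1) {                  // leaf: one range compare
    const CaseRange &C = Cases[Begin];
    return (Val >= C.Low && Val <= C.High) ? C.Dest : DefaultDest;
  }
  std::size_t Mid = Begin + (End - Begin) / 2; // pivot
  if (Val < Cases[Mid].Low)                    // "value < pivot" node
    return dispatch(Cases, Begin, Mid, Val, DefaultDest);
  return dispatch(Cases, Mid, End, Val, DefaultDest);
}

int main() {
  std::vector<CaseRange> Cases = {{1, 3, 10}, {7, 7, 11}, {9, 12, 12}};
  for (long V : {2, 7, 10, 5})
    std::printf("switch(%ld) -> %d\n", V,
                dispatch(Cases, 0, Cases.size(), V, /*default*/ -1));
  return 0;
}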
- ++numCmps; - } - - return numCmps; + return NumSimpleCases; } /// Replace the specified switch instruction with a sequence of chained if-then /// insts in a balanced binary search. void LowerSwitch::processSwitchInst(SwitchInst *SI, - SmallPtrSetImpl<BasicBlock*> &DeleteList) { - BasicBlock *CurBlock = SI->getParent(); - BasicBlock *OrigBlock = CurBlock; - Function *F = CurBlock->getParent(); + SmallPtrSetImpl<BasicBlock *> &DeleteList, + AssumptionCache *AC, LazyValueInfo *LVI) { + BasicBlock *OrigBlock = SI->getParent(); + Function *F = OrigBlock->getParent(); Value *Val = SI->getCondition(); // The value we are switching on... BasicBlock* Default = SI->getDefaultDest(); // Don't handle unreachable blocks. If there are successors with phis, this // would leave them behind with missing predecessors. - if ((CurBlock != &F->getEntryBlock() && pred_empty(CurBlock)) || - CurBlock->getSinglePredecessor() == CurBlock) { - DeleteList.insert(CurBlock); + if ((OrigBlock != &F->getEntryBlock() && pred_empty(OrigBlock)) || + OrigBlock->getSinglePredecessor() == OrigBlock) { + DeleteList.insert(OrigBlock); return; } + // Prepare cases vector. + CaseVector Cases; + const unsigned NumSimpleCases = Clusterify(Cases, SI); + LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() + << ". Total non-default cases: " << NumSimpleCases + << "\nCase clusters: " << Cases << "\n"); + // If there is only the default destination, just branch. - if (!SI->getNumCases()) { - BranchInst::Create(Default, CurBlock); + if (Cases.empty()) { + BranchInst::Create(Default, OrigBlock); + // Remove all the references from Default's PHIs to OrigBlock, but one. + fixPhis(Default, OrigBlock, OrigBlock); SI->eraseFromParent(); return; } - // Prepare cases vector. - CaseVector Cases; - unsigned numCmps = Clusterify(Cases, SI); - LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() - << ". Total compares: " << numCmps << "\n"); - LLVM_DEBUG(dbgs() << "Cases: " << Cases << "\n"); - (void)numCmps; - ConstantInt *LowerBound = nullptr; ConstantInt *UpperBound = nullptr; - std::vector<IntRange> UnreachableRanges; + bool DefaultIsUnreachableFromSwitch = false; if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) { // Make the bounds tightly fitted around the case value range, because we // know that the value passed to the switch must be exactly one of the case // values. - assert(!Cases.empty()); LowerBound = Cases.front().Low; UpperBound = Cases.back().High; + DefaultIsUnreachableFromSwitch = true; + } else { + // Constraining the range of the value being switched over helps eliminating + // unreachable BBs and minimizing the number of `add` instructions + // newLeafBlock ends up emitting. Running CorrelatedValuePropagation after + // LowerSwitch isn't as good, and also much more expensive in terms of + // compile time for the following reasons: + // 1. it processes many kinds of instructions, not just switches; + // 2. even if limited to icmp instructions only, it will have to process + // roughly C icmp's per switch, where C is the number of cases in the + // switch, while LowerSwitch only needs to call LVI once per switch. + const DataLayout &DL = F->getParent()->getDataLayout(); + KnownBits Known = computeKnownBits(Val, DL, /*Depth=*/0, AC, SI); + // TODO Shouldn't this create a signed range? 
+ ConstantRange KnownBitsRange = + ConstantRange::fromKnownBits(Known, /*IsSigned=*/false); + const ConstantRange LVIRange = LVI->getConstantRange(Val, OrigBlock, SI); + ConstantRange ValRange = KnownBitsRange.intersectWith(LVIRange); + // We delegate removal of unreachable non-default cases to other passes. In + // the unlikely event that some of them survived, we just conservatively + // maintain the invariant that all the cases lie between the bounds. This + // may, however, still render the default case effectively unreachable. + APInt Low = Cases.front().Low->getValue(); + APInt High = Cases.back().High->getValue(); + APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low); + APInt Max = APIntOps::smax(ValRange.getSignedMax(), High); + + LowerBound = ConstantInt::get(SI->getContext(), Min); + UpperBound = ConstantInt::get(SI->getContext(), Max); + DefaultIsUnreachableFromSwitch = (Min + (NumSimpleCases - 1) == Max); + } + + std::vector<IntRange> UnreachableRanges; + if (DefaultIsUnreachableFromSwitch) { DenseMap<BasicBlock *, unsigned> Popularity; unsigned MaxPop = 0; BasicBlock *PopSucc = nullptr; @@ -496,8 +560,10 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, #endif // As the default block in the switch is unreachable, update the PHI nodes - // (remove the entry to the default block) to reflect this. - Default->removePredecessor(OrigBlock); + // (remove all of the references to the default block) to reflect this. + const unsigned NumDefaultEdges = SI->getNumCases() + 1 - NumSimpleCases; + for (unsigned I = 0; I < NumDefaultEdges; ++I) + Default->removePredecessor(OrigBlock); // Use the most popular block as the new default, reducing the number of // cases. @@ -510,7 +576,7 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, // If there are no cases left, just branch. if (Cases.empty()) { - BranchInst::Create(Default, CurBlock); + BranchInst::Create(Default, OrigBlock); SI->eraseFromParent(); // As all the cases have been replaced with a single branch, only keep // one entry in the PHI nodes. @@ -518,12 +584,12 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, PopSucc->removePredecessor(OrigBlock); return; } - } - unsigned NrOfDefaults = (SI->getDefaultDest() == Default) ? 1 : 0; - for (const auto &Case : SI->cases()) - if (Case.getCaseSuccessor() == Default) - NrOfDefaults++; + // If the condition was a PHI node with the switch block as a predecessor + // removing predecessors may have caused the condition to be erased. + // Getting the condition value again here protects against that. + Val = SI->getCondition(); + } // Create a new, empty default block so that the new hierarchy of // if-then statements go to this and the PHI nodes are happy. @@ -537,14 +603,14 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, // If there are entries in any PHI nodes for the default edge, make sure // to update them as well. - fixPhis(Default, OrigBlock, NewDefault, NrOfDefaults); + fixPhis(Default, OrigBlock, NewDefault); // Branch to our shiny new if-then stuff... BranchInst::Create(SwitchBlock, OrigBlock); // We are now done with the switch instruction, delete it. BasicBlock *OldDefault = SI->getDefaultDest(); - CurBlock->getInstList().erase(SI); + OrigBlock->getInstList().erase(SI); // If the Default block has no more predecessors just add it to DeleteList. 
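The bounds computation above intersects what computeKnownBits and LVI prove about the condition, widens that range to cover all case values, and then treats the default as unreachable exactly when the widened range contains no values beyond the cases. With plain integers and invented values, the test looks like this:

#include <algorithm>
#include <cstdio>

int main() {
  // Invented results of the known-bits/LVI intersection and the case list.
  long RangeMin = 0, RangeMax = 3;   // proven range of the switch condition
  long CaseLow = 0, CaseHigh = 3;    // smallest and largest case values
  unsigned NumSimpleCases = 4;       // distinct non-default case values

  long Min = std::min(RangeMin, CaseLow);
  long Max = std::max(RangeMax, CaseHigh);
  // [Min, Max] holds Max - Min + 1 integers; if that equals the number of
  // cases, every possible condition value hits some case.
  bool DefaultIsUnreachable = (Min + (long)(NumSimpleCases - 1) == Max);
  std::printf("default reachable from switch: %s\n",
              DefaultIsUnreachable ? "no" : "yes");
  return 0;
}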
if (pred_begin(OldDefault) == pred_end(OldDefault)) diff --git a/lib/Transforms/Utils/Mem2Reg.cpp b/lib/Transforms/Utils/Mem2Reg.cpp index 23145e584751..cd2c81b6abc8 100644 --- a/lib/Transforms/Utils/Mem2Reg.cpp +++ b/lib/Transforms/Utils/Mem2Reg.cpp @@ -1,9 +1,8 @@ //===- Mem2Reg.cpp - The -mem2reg pass, a wrapper around the Utils lib ----===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/MetaRenamer.cpp b/lib/Transforms/Utils/MetaRenamer.cpp index 88d595ee02ab..c0b7edc547fd 100644 --- a/lib/Transforms/Utils/MetaRenamer.cpp +++ b/lib/Transforms/Utils/MetaRenamer.cpp @@ -1,9 +1,8 @@ //===- MetaRenamer.cpp - Rename everything with metasyntatic names --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp index ae5e72ea4d30..c84beceee191 100644 --- a/lib/Transforms/Utils/ModuleUtils.cpp +++ b/lib/Transforms/Utils/ModuleUtils.cpp @@ -1,9 +1,8 @@ //===-- ModuleUtils.cpp - Functions to manipulate Modules -----------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -28,44 +27,24 @@ static void appendToGlobalArray(const char *Array, Module &M, Function *F, // Get the current set of static global constructors and add the new ctor // to the list. SmallVector<Constant *, 16> CurrentCtors; - StructType *EltTy; + StructType *EltTy = StructType::get( + IRB.getInt32Ty(), PointerType::getUnqual(FnTy), IRB.getInt8PtrTy()); if (GlobalVariable *GVCtor = M.getNamedGlobal(Array)) { - ArrayType *ATy = cast<ArrayType>(GVCtor->getValueType()); - StructType *OldEltTy = cast<StructType>(ATy->getElementType()); - // Upgrade a 2-field global array type to the new 3-field format if needed. 
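The ModuleUtils hunk that begins above (its removed upgrade logic continues below) drops support for the legacy two-field llvm.global_ctors entries, so appendToGlobalArray now always emits the three-field form { i32 priority, void ()* ctor, i8* data }. A minimal sketch of a client registering a constructor through the public appendToGlobalCtors helper; addIllustrativeCtor and my_module_ctor are invented names:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

using namespace llvm;

static void addIllustrativeCtor(Module &M) {
  LLVMContext &C = M.getContext();
  FunctionType *VoidFnTy =
      FunctionType::get(Type::getVoidTy(C), /*isVarArg=*/false);
  Function *Ctor = Function::Create(VoidFnTy, GlobalValue::InternalLinkage,
                                    "my_module_ctor", &M);
  IRBuilder<> IRB(BasicBlock::Create(C, "entry", Ctor));
  IRB.CreateRetVoid();
  // A null Data pointer still yields a 3-field entry whose i8* operand is
  // null; a non-null Constant* would be pointer-cast to i8*.
  appendToGlobalCtors(M, Ctor, /*Priority=*/65535, /*Data=*/nullptr);
}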
- if (Data && OldEltTy->getNumElements() < 3) - EltTy = StructType::get(IRB.getInt32Ty(), PointerType::getUnqual(FnTy), - IRB.getInt8PtrTy()); - else - EltTy = OldEltTy; if (Constant *Init = GVCtor->getInitializer()) { unsigned n = Init->getNumOperands(); CurrentCtors.reserve(n + 1); - for (unsigned i = 0; i != n; ++i) { - auto Ctor = cast<Constant>(Init->getOperand(i)); - if (EltTy != OldEltTy) - Ctor = - ConstantStruct::get(EltTy, Ctor->getAggregateElement((unsigned)0), - Ctor->getAggregateElement(1), - Constant::getNullValue(IRB.getInt8PtrTy())); - CurrentCtors.push_back(Ctor); - } + for (unsigned i = 0; i != n; ++i) + CurrentCtors.push_back(cast<Constant>(Init->getOperand(i))); } GVCtor->eraseFromParent(); - } else { - // Use the new three-field struct if there isn't one already. - EltTy = StructType::get(IRB.getInt32Ty(), PointerType::getUnqual(FnTy), - IRB.getInt8PtrTy()); } - // Build a 2 or 3 field global_ctor entry. We don't take a comdat key. + // Build a 3 field global_ctor entry. We don't take a comdat key. Constant *CSVals[3]; CSVals[0] = IRB.getInt32(Priority); CSVals[1] = F; - // FIXME: Drop support for the two element form in LLVM 4.0. - if (EltTy->getNumElements() >= 3) - CSVals[2] = Data ? ConstantExpr::getPointerCast(Data, IRB.getInt8PtrTy()) - : Constant::getNullValue(IRB.getInt8PtrTy()); + CSVals[2] = Data ? ConstantExpr::getPointerCast(Data, IRB.getInt8PtrTy()) + : Constant::getNullValue(IRB.getInt8PtrTy()); Constant *RuntimeCtorInit = ConstantStruct::get(EltTy, makeArrayRef(CSVals, EltTy->getNumElements())); @@ -127,36 +106,24 @@ void llvm::appendToCompilerUsed(Module &M, ArrayRef<GlobalValue *> Values) { appendToUsedList(M, "llvm.compiler.used", Values); } -Function *llvm::checkSanitizerInterfaceFunction(Constant *FuncOrBitcast) { - if (isa<Function>(FuncOrBitcast)) - return cast<Function>(FuncOrBitcast); - FuncOrBitcast->print(errs()); - errs() << '\n'; - std::string Err; - raw_string_ostream Stream(Err); - Stream << "Sanitizer interface function redefined: " << *FuncOrBitcast; - report_fatal_error(Err); -} - -Function *llvm::declareSanitizerInitFunction(Module &M, StringRef InitName, - ArrayRef<Type *> InitArgTypes) { +FunctionCallee +llvm::declareSanitizerInitFunction(Module &M, StringRef InitName, + ArrayRef<Type *> InitArgTypes) { assert(!InitName.empty() && "Expected init function name"); - Function *F = checkSanitizerInterfaceFunction(M.getOrInsertFunction( + return M.getOrInsertFunction( InitName, FunctionType::get(Type::getVoidTy(M.getContext()), InitArgTypes, false), - AttributeList())); - F->setLinkage(Function::ExternalLinkage); - return F; + AttributeList()); } -std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions( +std::pair<Function *, FunctionCallee> llvm::createSanitizerCtorAndInitFunctions( Module &M, StringRef CtorName, StringRef InitName, ArrayRef<Type *> InitArgTypes, ArrayRef<Value *> InitArgs, StringRef VersionCheckName) { assert(!InitName.empty() && "Expected init function name"); assert(InitArgs.size() == InitArgTypes.size() && "Sanitizer's init function expects different number of arguments"); - Function *InitFunction = + FunctionCallee InitFunction = declareSanitizerInitFunction(M, InitName, InitArgTypes); Function *Ctor = Function::Create( FunctionType::get(Type::getVoidTy(M.getContext()), false), @@ -165,20 +132,19 @@ std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions( IRBuilder<> IRB(ReturnInst::Create(M.getContext(), CtorBB)); IRB.CreateCall(InitFunction, InitArgs); if 
(!VersionCheckName.empty()) { - Function *VersionCheckFunction = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - VersionCheckName, FunctionType::get(IRB.getVoidTy(), {}, false), - AttributeList())); + FunctionCallee VersionCheckFunction = M.getOrInsertFunction( + VersionCheckName, FunctionType::get(IRB.getVoidTy(), {}, false), + AttributeList()); IRB.CreateCall(VersionCheckFunction, {}); } return std::make_pair(Ctor, InitFunction); } -std::pair<Function *, Function *> +std::pair<Function *, FunctionCallee> llvm::getOrCreateSanitizerCtorAndInitFunctions( Module &M, StringRef CtorName, StringRef InitName, ArrayRef<Type *> InitArgTypes, ArrayRef<Value *> InitArgs, - function_ref<void(Function *, Function *)> FunctionsCreatedCallback, + function_ref<void(Function *, FunctionCallee)> FunctionsCreatedCallback, StringRef VersionCheckName) { assert(!CtorName.empty() && "Expected ctor function name"); @@ -189,7 +155,8 @@ llvm::getOrCreateSanitizerCtorAndInitFunctions( Ctor->getReturnType() == Type::getVoidTy(M.getContext())) return {Ctor, declareSanitizerInitFunction(M, InitName, InitArgTypes)}; - Function *Ctor, *InitFunction; + Function *Ctor; + FunctionCallee InitFunction; std::tie(Ctor, InitFunction) = llvm::createSanitizerCtorAndInitFunctions( M, CtorName, InitName, InitArgTypes, InitArgs, VersionCheckName); FunctionsCreatedCallback(Ctor, InitFunction); @@ -208,9 +175,10 @@ Function *llvm::getOrCreateInitFunction(Module &M, StringRef Name) { } return F; } - Function *F = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - Name, AttributeList(), Type::getVoidTy(M.getContext()))); - F->setLinkage(Function::ExternalLinkage); + Function *F = + cast<Function>(M.getOrInsertFunction(Name, AttributeList(), + Type::getVoidTy(M.getContext())) + .getCallee()); appendToGlobalCtors(M, F, 0); diff --git a/lib/Transforms/Utils/NameAnonGlobals.cpp b/lib/Transforms/Utils/NameAnonGlobals.cpp index 34dc1cccdd5b..ac8991e9d475 100644 --- a/lib/Transforms/Utils/NameAnonGlobals.cpp +++ b/lib/Transforms/Utils/NameAnonGlobals.cpp @@ -1,9 +1,8 @@ //===- NameAnonGlobals.cpp - ThinLTO Support: Name Unnamed Globals --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/PredicateInfo.cpp b/lib/Transforms/Utils/PredicateInfo.cpp index 585ce6b4c118..bdf24d80bd17 100644 --- a/lib/Transforms/Utils/PredicateInfo.cpp +++ b/lib/Transforms/Utils/PredicateInfo.cpp @@ -1,9 +1,8 @@ //===-- PredicateInfo.cpp - PredicateInfo Builder--------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
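The sanitizer helpers above now traffic in FunctionCallee, matching Module::getOrInsertFunction, which returns the callee value together with its FunctionType instead of a bare Function* (the returned callee may be a bitcast when an existing declaration has a different prototype). A minimal sketch of the pattern; "__my_runtime_init" is a made-up runtime entry point and the builder is assumed to already have an insertion point:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"

using namespace llvm;

static CallInst *emitRuntimeInitCall(Module &M, IRBuilder<> &IRB) {
  LLVMContext &C = M.getContext();
  FunctionCallee Init = M.getOrInsertFunction(
      "__my_runtime_init",
      FunctionType::get(Type::getVoidTy(C), {IRB.getInt8PtrTy()},
                        /*isVarArg=*/false));
  // FunctionCallee carries both the callee and its type, so it can be handed
  // straight to CreateCall; cast to Function* only where a real Function is
  // required (e.g. when registering it in llvm.global_ctors).
  Value *Arg = Constant::getNullValue(IRB.getInt8PtrTy());
  return IRB.CreateCall(Init, {Arg});
}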
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------===// // @@ -474,7 +473,8 @@ void PredicateInfo::buildPredicateInfo() { } for (auto &Assume : AC.assumptions()) { if (auto *II = dyn_cast_or_null<IntrinsicInst>(Assume)) - processAssume(II, II->getParent(), OpsToRename); + if (DT.isReachableFromEntry(II->getParent())) + processAssume(II, II->getParent(), OpsToRename); } // Now rename all our operations. renameUses(OpsToRename); @@ -489,8 +489,10 @@ void PredicateInfo::buildPredicateInfo() { // tricky (FIXME). static Function *getCopyDeclaration(Module *M, Type *Ty) { std::string Name = "llvm.ssa.copy." + utostr((uintptr_t) Ty); - return cast<Function>(M->getOrInsertFunction( - Name, getType(M->getContext(), Intrinsic::ssa_copy, Ty))); + return cast<Function>( + M->getOrInsertFunction(Name, + getType(M->getContext(), Intrinsic::ssa_copy, Ty)) + .getCallee()); } // Given the renaming stack, make all the operands currently on the stack real @@ -633,7 +635,7 @@ void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpSet) { // uses in the same instruction do not have a strict sort order // currently and will be considered equal. We could get rid of the // stable sort by creating one if we wanted. - std::stable_sort(OrderedUses.begin(), OrderedUses.end(), Compare); + llvm::stable_sort(OrderedUses, Compare); SmallVector<ValueDFS, 8> RenameStack; // For each use, sorted into dfs order, push values and replaces uses with // top of stack, which will represent the reaching def. diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 91e4f4254b3e..d58e1ea574ef 100644 --- a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -1,9 +1,8 @@ //===- PromoteMemoryToRegister.cpp - Convert allocas to registers ---------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -114,7 +113,6 @@ struct AllocaInfo { BasicBlock *OnlyBlock; bool OnlyUsedInOneBlock; - Value *AllocaPointerVal; TinyPtrVector<DbgVariableIntrinsic *> DbgDeclares; void clear() { @@ -123,7 +121,6 @@ struct AllocaInfo { OnlyStore = nullptr; OnlyBlock = nullptr; OnlyUsedInOneBlock = true; - AllocaPointerVal = nullptr; DbgDeclares.clear(); } @@ -141,14 +138,12 @@ struct AllocaInfo { if (StoreInst *SI = dyn_cast<StoreInst>(User)) { // Remember the basic blocks which define new values for the alloca DefiningBlocks.push_back(SI->getParent()); - AllocaPointerVal = SI->getOperand(0); OnlyStore = SI; } else { LoadInst *LI = cast<LoadInst>(User); // Otherwise it must be a load instruction, keep track of variable // reads. UsingBlocks.push_back(LI->getParent()); - AllocaPointerVal = LI; } if (OnlyUsedInOneBlock) { @@ -254,11 +249,6 @@ struct PromoteMem2Reg { /// to. DenseMap<PHINode *, unsigned> PhiToAllocaMap; - /// If we are updating an AliasSetTracker, then for each alloca that is of - /// pointer type, we keep track of what to copyValue to the inserted PHI - /// nodes here. 
- std::vector<Value *> PointerAllocaValues; - /// For each alloca, we keep track of the dbg.declare intrinsic that /// describes it, if any, so that we can convert it to a dbg.value /// intrinsic if the alloca gets promoted. @@ -367,10 +357,8 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, for (auto UI = AI->user_begin(), E = AI->user_end(); UI != E;) { Instruction *UserInst = cast<Instruction>(*UI++); - if (!isa<LoadInst>(UserInst)) { - assert(UserInst == OnlyStore && "Should only have load/stores"); + if (UserInst == OnlyStore) continue; - } LoadInst *LI = cast<LoadInst>(UserInst); // Okay, if we have a load from the alloca, we want to replace it with the @@ -390,8 +378,7 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, Info.UsingBlocks.push_back(StoreBB); continue; } - } else if (LI->getParent() != StoreBB && - !DT.dominates(StoreBB, LI->getParent())) { + } else if (!DT.dominates(StoreBB, LI->getParent())) { // If the load and store are in different blocks, use BB dominance to // check their relationships. If the store doesn't dom the use, bail // out. @@ -429,14 +416,12 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); ConvertDebugDeclareToDebugValue(DII, Info.OnlyStore, DIB); DII->eraseFromParent(); - LBI.deleteValue(DII); } // Remove the (now dead) store and alloca. Info.OnlyStore->eraseFromParent(); LBI.deleteValue(Info.OnlyStore); AI->eraseFromParent(); - LBI.deleteValue(AI); return true; } @@ -488,11 +473,10 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, unsigned LoadIdx = LBI.getInstructionIndex(LI); // Find the nearest store that has a lower index than this load. - StoresByIndexTy::iterator I = - std::lower_bound(StoresByIndex.begin(), StoresByIndex.end(), - std::make_pair(LoadIdx, - static_cast<StoreInst *>(nullptr)), - less_first()); + StoresByIndexTy::iterator I = llvm::lower_bound( + StoresByIndex, + std::make_pair(LoadIdx, static_cast<StoreInst *>(nullptr)), + less_first()); if (I == StoresByIndex.begin()) { if (StoresByIndex.empty()) // If there are no stores, the load takes the undef value. @@ -535,13 +519,10 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, } AI->eraseFromParent(); - LBI.deleteValue(AI); // The alloca's debuginfo can be removed as well. - for (DbgVariableIntrinsic *DII : Info.DbgDeclares) { + for (DbgVariableIntrinsic *DII : Info.DbgDeclares) DII->eraseFromParent(); - LBI.deleteValue(DII); - } ++NumLocalPromoted; return true; @@ -620,8 +601,8 @@ void PromoteMem2Reg::run() { // dead phi nodes. // Unique the set of defining blocks for efficient lookup. - SmallPtrSet<BasicBlock *, 32> DefBlocks; - DefBlocks.insert(Info.DefiningBlocks.begin(), Info.DefiningBlocks.end()); + SmallPtrSet<BasicBlock *, 32> DefBlocks(Info.DefiningBlocks.begin(), + Info.DefiningBlocks.end()); // Determine which blocks the value is live in. These are blocks which lead // to uses. 
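promoteSingleBlockAlloca above switches to the range-based llvm::lower_bound (and PromoteMem2Reg::run below to llvm::sort) from llvm/ADT/STLExtras.h, passing the container instead of an iterator pair. A small self-contained sketch of the same search over (instruction index, store) pairs ordered with less_first; DummyStore stands in for StoreInst* purely for illustration:

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdio>
#include <iterator>
#include <utility>

// Stand-in for StoreInst*; only the ordering on the index matters here.
struct DummyStore {};

int main() {
  // (instruction index within the block, store), kept sorted by index.
  llvm::SmallVector<std::pair<unsigned, DummyStore *>, 8> StoresByIndex = {
      {3, nullptr}, {7, nullptr}, {12, nullptr}};

  unsigned LoadIdx = 9;
  // First store whose index is not less than the load's index; everything
  // before it is a store that happens before the load.
  auto I = llvm::lower_bound(
      StoresByIndex,
      std::make_pair(LoadIdx, static_cast<DummyStore *>(nullptr)),
      llvm::less_first());
  std::printf("stores before load: %u\n",
              (unsigned)std::distance(StoresByIndex.begin(), I)); // 2
  return 0;
}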
@@ -636,10 +617,9 @@ void PromoteMem2Reg::run() { IDF.setDefiningBlocks(DefBlocks); SmallVector<BasicBlock *, 32> PHIBlocks; IDF.calculate(PHIBlocks); - if (PHIBlocks.size() > 1) - llvm::sort(PHIBlocks, [this](BasicBlock *A, BasicBlock *B) { - return BBNumbers.lookup(A) < BBNumbers.lookup(B); - }); + llvm::sort(PHIBlocks, [this](BasicBlock *A, BasicBlock *B) { + return BBNumbers.find(A)->second < BBNumbers.find(B)->second; + }); unsigned CurrentVersion = 0; for (BasicBlock *BB : PHIBlocks) @@ -751,7 +731,7 @@ void PromoteMem2Reg::run() { // basic blocks. Start by sorting the incoming predecessors for efficient // access. auto CompareBBNumbers = [this](BasicBlock *A, BasicBlock *B) { - return BBNumbers.lookup(A) < BBNumbers.lookup(B); + return BBNumbers.find(A)->second < BBNumbers.find(B)->second; }; llvm::sort(Preds, CompareBBNumbers); @@ -759,9 +739,8 @@ void PromoteMem2Reg::run() { // them from the Preds list. for (unsigned i = 0, e = SomePHI->getNumIncomingValues(); i != e; ++i) { // Do a log(n) search of the Preds list for the entry we want. - SmallVectorImpl<BasicBlock *>::iterator EntIt = std::lower_bound( - Preds.begin(), Preds.end(), SomePHI->getIncomingBlock(i), - CompareBBNumbers); + SmallVectorImpl<BasicBlock *>::iterator EntIt = llvm::lower_bound( + Preds, SomePHI->getIncomingBlock(i), CompareBBNumbers); assert(EntIt != Preds.end() && *EntIt == SomePHI->getIncomingBlock(i) && "PHI node has entry for a block which is not a predecessor!"); @@ -825,14 +804,11 @@ void PromoteMem2Reg::ComputeLiveInBlocks( break; } - if (LoadInst *LI = dyn_cast<LoadInst>(I)) { - if (LI->getOperand(0) != AI) - continue; - + if (LoadInst *LI = dyn_cast<LoadInst>(I)) // Okay, we found a load before a store to the alloca. It is actually // live into this block. - break; - } + if (LI->getOperand(0) == AI) + break; } } diff --git a/lib/Transforms/Utils/SSAUpdater.cpp b/lib/Transforms/Utils/SSAUpdater.cpp index 9e5fb0e7172d..bffdd115d940 100644 --- a/lib/Transforms/Utils/SSAUpdater.cpp +++ b/lib/Transforms/Utils/SSAUpdater.cpp @@ -1,9 +1,8 @@ //===- SSAUpdater.cpp - Unstructured SSA Update Tool ----------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -350,8 +349,7 @@ LoadAndStorePromoter(ArrayRef<const Instruction *> Insts, SSA.Initialize(SomeVal->getType(), BaseName); } -void LoadAndStorePromoter:: -run(const SmallVectorImpl<Instruction *> &Insts) const { +void LoadAndStorePromoter::run(const SmallVectorImpl<Instruction *> &Insts) { // First step: bucket up uses of the alloca by the block they occur in. // This is important because we have to handle multiple defs/uses in a block // ourselves: SSAUpdater is purely for cross-block references. diff --git a/lib/Transforms/Utils/SSAUpdaterBulk.cpp b/lib/Transforms/Utils/SSAUpdaterBulk.cpp index 397bac2940a4..917d5e0a1ef0 100644 --- a/lib/Transforms/Utils/SSAUpdaterBulk.cpp +++ b/lib/Transforms/Utils/SSAUpdaterBulk.cpp @@ -1,9 +1,8 @@ //===- SSAUpdaterBulk.cpp - Unstructured SSA Update Tool ------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. 
See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/SanitizerStats.cpp b/lib/Transforms/Utils/SanitizerStats.cpp index 8c23957ac43e..a1313c77ed77 100644 --- a/lib/Transforms/Utils/SanitizerStats.cpp +++ b/lib/Transforms/Utils/SanitizerStats.cpp @@ -1,9 +1,8 @@ //===- SanitizerStats.cpp - Sanitizer statistics gathering ----------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -57,8 +56,8 @@ void SanitizerStatReport::create(IRBuilder<> &B, SanitizerStatKind SK) { FunctionType *StatReportTy = FunctionType::get(B.getVoidTy(), Int8PtrTy, false); - Constant *StatReport = M->getOrInsertFunction( - "__sanitizer_stat_report", StatReportTy); + FunctionCallee StatReport = + M->getOrInsertFunction("__sanitizer_stat_report", StatReportTy); auto InitAddr = ConstantExpr::getGetElementPtr( EmptyModuleStatsTy, ModuleStatsGV, @@ -98,8 +97,8 @@ void SanitizerStatReport::finish() { IRBuilder<> B(BB); FunctionType *StatInitTy = FunctionType::get(VoidTy, Int8PtrTy, false); - Constant *StatInit = M->getOrInsertFunction( - "__sanitizer_stat_init", StatInitTy); + FunctionCallee StatInit = + M->getOrInsertFunction("__sanitizer_stat_init", StatInitTy); B.CreateCall(StatInit, ConstantExpr::getBitCast(NewModuleStatsGV, Int8PtrTy)); B.CreateRetVoid(); diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index 03b73954321d..11651d040dc0 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -1,9 +1,8 @@ //===- SimplifyCFG.cpp - Code to perform CFG simplification ---------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -26,8 +25,9 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -66,6 +66,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> @@ -292,9 +293,13 @@ isProfitableToFoldUnconditional(BranchInst *SI1, BranchInst *SI2, /// will be the same as those coming in from ExistPred, an existing predecessor /// of Succ. static void AddPredecessorToBlock(BasicBlock *Succ, BasicBlock *NewPred, - BasicBlock *ExistPred) { + BasicBlock *ExistPred, + MemorySSAUpdater *MSSAU = nullptr) { for (PHINode &PN : Succ->phis()) PN.addIncoming(PN.getIncomingValueForBlock(ExistPred), NewPred); + if (MSSAU) + if (auto *MPhi = MSSAU->getMemorySSA()->getMemoryAccess(Succ)) + MPhi->addIncoming(MPhi->getIncomingValueForBlock(ExistPred), NewPred); } /// Compute an abstract "cost" of speculating the given instruction, @@ -670,7 +675,8 @@ private: } // end anonymous namespace -static void EraseTerminatorAndDCECond(Instruction *TI) { +static void EraseTerminatorAndDCECond(Instruction *TI, + MemorySSAUpdater *MSSAU = nullptr) { Instruction *Cond = nullptr; if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { Cond = dyn_cast<Instruction>(SI->getCondition()); @@ -683,7 +689,7 @@ static void EraseTerminatorAndDCECond(Instruction *TI) { TI->eraseFromParent(); if (Cond) - RecursivelyDeleteTriviallyDeadInstructions(Cond); + RecursivelyDeleteTriviallyDeadInstructions(Cond, nullptr, MSSAU); } /// Return true if the specified terminator checks @@ -858,7 +864,7 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( return true; } - SwitchInst *SI = cast<SwitchInst>(TI); + SwitchInstProfUpdateWrapper SI = *cast<SwitchInst>(TI); // Okay, TI has cases that are statically dead, prune them away. SmallPtrSet<Constant *, 16> DeadCases; for (unsigned i = 0, e = PredCases.size(); i != e; ++i) @@ -867,30 +873,13 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( LLVM_DEBUG(dbgs() << "Threading pred instr: " << *Pred->getTerminator() << "Through successor TI: " << *TI); - // Collect branch weights into a vector. 
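The hunk that begins above (its case-removal loop continues below) replaces the hand-maintained branch-weight vector with SwitchInstProfUpdateWrapper, which reads the !prof branch_weights once, keeps them in step as cases are added or removed through the wrapper, and writes them back. A minimal sketch of the same pattern; removeCasesTargeting and DeadBB are illustrative names:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Remove every case whose successor is DeadBB.  Going through the wrapper
// keeps branch_weights consistent without a manual Weights vector.
static void removeCasesTargeting(SwitchInst *SI, BasicBlock *DeadBB) {
  SwitchInstProfUpdateWrapper SIW(*SI);
  for (auto I = SIW->case_begin(), E = SIW->case_end(); I != E;) {
    if (I->getCaseSuccessor() != DeadBB) {
      ++I;
      continue;
    }
    DeadBB->removePredecessor(SIW->getParent());
    I = SIW.removeCase(I); // drops the matching weight as well
    E = SIW->case_end();
  }
}

The same wrapper shows up again below in tryToSimplifyUncondBranchWithICmpInIt, SimplifyUnreachable and eliminateDeadSwitchCases.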
- SmallVector<uint32_t, 8> Weights; - MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); - bool HasWeight = MD && (MD->getNumOperands() == 2 + SI->getNumCases()); - if (HasWeight) - for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; - ++MD_i) { - ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(MD_i)); - Weights.push_back(CI->getValue().getZExtValue()); - } for (SwitchInst::CaseIt i = SI->case_end(), e = SI->case_begin(); i != e;) { --i; if (DeadCases.count(i->getCaseValue())) { - if (HasWeight) { - std::swap(Weights[i->getCaseIndex() + 1], Weights.back()); - Weights.pop_back(); - } i->getCaseSuccessor()->removePredecessor(TI->getParent()); - SI->removeCase(i); + SI.removeCase(i); } } - if (HasWeight && Weights.size() >= 2) - setBranchWeights(SI, Weights); - LLVM_DEBUG(dbgs() << "Leaving: " << *TI << "\n"); return true; } @@ -1266,8 +1255,10 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, while (isa<DbgInfoIntrinsic>(I2)) I2 = &*BB2_Itr++; } + // FIXME: Can we define a safety predicate for CallBr? if (isa<PHINode>(I1) || !I1->isIdenticalToWhenDefined(I2) || - (isa<InvokeInst>(I1) && !isSafeToHoistInvoke(BB1, BB2, I1, I2))) + (isa<InvokeInst>(I1) && !isSafeToHoistInvoke(BB1, BB2, I1, I2)) || + isa<CallBrInst>(I1)) return false; BasicBlock *BIParent = BI->getParent(); @@ -1350,9 +1341,14 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, HoistTerminator: // It may not be possible to hoist an invoke. + // FIXME: Can we define a safety predicate for CallBr? if (isa<InvokeInst>(I1) && !isSafeToHoistInvoke(BB1, BB2, I1, I2)) return Changed; + // TODO: callbr hoisting currently disabled pending further study. + if (isa<CallBrInst>(I1)) + return Changed; + for (BasicBlock *Succ : successors(BB1)) { for (PHINode &PN : Succ->phis()) { Value *BB1V = PN.getIncomingValueForBlock(BB1); @@ -1432,9 +1428,10 @@ HoistTerminator: static bool canSinkInstructions( ArrayRef<Instruction *> Insts, DenseMap<Instruction *, SmallVector<Value *, 4>> &PHIOperands) { - // Prune out obviously bad instructions to move. Any non-store instruction - // must have exactly one use, and we check later that use is by a single, - // common PHI instruction in the successor. + // Prune out obviously bad instructions to move. Each instruction must have + // exactly zero or one use, and we check later that use is by a single, common + // PHI instruction in the successor. + bool HasUse = !Insts.front()->user_empty(); for (auto *I : Insts) { // These instructions may change or break semantics if moved. if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) || @@ -1444,13 +1441,14 @@ static bool canSinkInstructions( // Conservatively return false if I is an inline-asm instruction. Sinking // and merging inline-asm instructions can potentially create arguments // that cannot satisfy the inline-asm constraints. - if (const auto *C = dyn_cast<CallInst>(I)) + if (const auto *C = dyn_cast<CallBase>(I)) if (C->isInlineAsm()) return false; - // Everything must have only one use too, apart from stores which - // have no uses. - if (!isa<StoreInst>(I) && !I->hasOneUse()) + // Each instruction must have zero or one use. + if (HasUse && !I->hasOneUse()) + return false; + if (!HasUse && !I->user_empty()) return false; } @@ -1459,11 +1457,11 @@ static bool canSinkInstructions( if (!I->isSameOperationAs(I0)) return false; - // All instructions in Insts are known to be the same opcode. 
If they aren't - // stores, check the only user of each is a PHI or in the same block as the - // instruction, because if a user is in the same block as an instruction - // we're contemplating sinking, it must already be determined to be sinkable. - if (!isa<StoreInst>(I0)) { + // All instructions in Insts are known to be the same opcode. If they have a + // use, check that the only user is a PHI or in the same block as the + // instruction, because if a user is in the same block as an instruction we're + // contemplating sinking, it must already be determined to be sinkable. + if (HasUse) { auto *PNUse = dyn_cast<PHINode>(*I0->user_begin()); auto *Succ = I0->getParent()->getTerminator()->getSuccessor(0); if (!all_of(Insts, [&PNUse,&Succ](const Instruction *I) -> bool { @@ -1507,7 +1505,7 @@ static bool canSinkInstructions( // We can't create a PHI from this GEP. return false; // Don't create indirect calls! The called value is the final operand. - if ((isa<CallInst>(I0) || isa<InvokeInst>(I0)) && OI == OE - 1) { + if (isa<CallBase>(I0) && OI == OE - 1) { // FIXME: if the call was *already* indirect, we should do this. return false; } @@ -1541,7 +1539,7 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { // it is slightly over-aggressive - it gets confused by commutative instructions // so double-check it here. Instruction *I0 = Insts.front(); - if (!isa<StoreInst>(I0)) { + if (!I0->user_empty()) { auto *PNUse = dyn_cast<PHINode>(*I0->user_begin()); if (!all_of(Insts, [&PNUse](const Instruction *I) -> bool { auto *U = cast<Instruction>(*I->user_begin()); @@ -1599,11 +1597,10 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { I0->andIRFlags(I); } - if (!isa<StoreInst>(I0)) { + if (!I0->user_empty()) { // canSinkLastInstruction checked that all instructions were used by // one and only one PHI node. Find that now, RAUW it to our common // instruction and nuke it. - assert(I0->hasOneUse()); auto *PN = cast<PHINode>(*I0->user_begin()); PN->replaceAllUsesWith(I0); PN->eraseFromParent(); @@ -2203,7 +2200,8 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL, BasicBlock *EdgeBB = BasicBlock::Create(BB->getContext(), RealDest->getName() + ".critedge", RealDest->getParent(), RealDest); - BranchInst::Create(RealDest, EdgeBB); + BranchInst *CritEdgeBranch = BranchInst::Create(RealDest, EdgeBB); + CritEdgeBranch->setDebugLoc(BI->getDebugLoc()); // Update PHI nodes. AddPredecessorToBlock(RealDest, EdgeBB, BB); @@ -2539,7 +2537,8 @@ static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI, /// If this basic block is simple enough, and if a predecessor branches to us /// and one of our successors, fold the block into the predecessor and use /// logical operations to pick the right destination. -bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { +bool llvm::FoldBranchToCommonDest(BranchInst *BI, MemorySSAUpdater *MSSAU, + unsigned BonusInstThreshold) { BasicBlock *BB = BI->getParent(); const unsigned PredCount = pred_size(BB); @@ -2594,7 +2593,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // unconditionally. We denote all involved instructions except the condition // as "bonus instructions", and only allow this transformation when the // number of the bonus instructions we'll need to create when cloning into - // each predecessor does not exceed a certain threshold. + // each predecessor does not exceed a certain threshold. 
unsigned NumBonusInsts = 0; for (auto I = BB->begin(); Cond != &*I; ++I) { // Ignore dbg intrinsics. @@ -2611,7 +2610,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // and Cond. // Account for the cost of duplicating this instruction into each - // predecessor. + // predecessor. NumBonusInsts += PredCount; // Early exits once we reach the limit. if (NumBonusInsts > BonusInstThreshold) @@ -2750,7 +2749,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { (SuccFalseWeight + SuccTrueWeight) + PredTrueWeight * SuccFalseWeight); } - AddPredecessorToBlock(TrueDest, PredBlock, BB); + AddPredecessorToBlock(TrueDest, PredBlock, BB, MSSAU); PBI->setSuccessor(0, TrueDest); } if (PBI->getSuccessor(1) == BB) { @@ -2765,7 +2764,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // FalseWeight is FalseWeight for PBI * FalseWeight for BI. NewWeights.push_back(PredFalseWeight * SuccFalseWeight); } - AddPredecessorToBlock(FalseDest, PredBlock, BB); + AddPredecessorToBlock(FalseDest, PredBlock, BB, MSSAU); PBI->setSuccessor(1, FalseDest); } if (NewWeights.size() == 2) { @@ -2810,12 +2809,17 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { } } // Update PHI Node. - PHIs[i]->setIncomingValue(PHIs[i]->getBasicBlockIndex(PBI->getParent()), - MergedCond); + PHIs[i]->setIncomingValueForBlock(PBI->getParent(), MergedCond); } + + // PBI is changed to branch to TrueDest below. Remove itself from + // potential phis from all other successors. + if (MSSAU) + MSSAU->changeCondBranchToUnconditionalTo(PBI, TrueDest); + // Change PBI from Conditional to Unconditional. BranchInst *New_PBI = BranchInst::Create(TrueDest, PBI); - EraseTerminatorAndDCECond(PBI); + EraseTerminatorAndDCECond(PBI, MSSAU); PBI = New_PBI; } @@ -3430,7 +3434,7 @@ static bool SimplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond, KeepEdge2 = nullptr; else Succ->removePredecessor(OldTerm->getParent(), - /*DontDeleteUselessPHIs=*/true); + /*KeepOneInputPHIs=*/true); } IRBuilder<> Builder(OldTerm); @@ -3622,20 +3626,16 @@ bool SimplifyCFGOpt::tryToSimplifyUncondBranchWithICmpInIt( // the switch to the merge point on the compared value. BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "switch.edge", BB->getParent(), BB); - SmallVector<uint64_t, 8> Weights; - bool HasWeights = HasBranchWeights(SI); - if (HasWeights) { - GetBranchWeights(SI, Weights); - if (Weights.size() == 1 + SI->getNumCases()) { - // Split weight for default case to case for "Cst". - Weights[0] = (Weights[0] + 1) >> 1; - Weights.push_back(Weights[0]); - - SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end()); - setBranchWeights(SI, MDWeights); + { + SwitchInstProfUpdateWrapper SIW(*SI); + auto W0 = SIW.getSuccessorWeight(0); + SwitchInstProfUpdateWrapper::CaseWeightOpt NewW; + if (W0) { + NewW = ((uint64_t(*W0) + 1) >> 1); + SIW.setSuccessorWeight(0, *NewW); } + SIW.addCase(Cst, NewBB, NewW); } - SI->addCase(Cst, NewBB); // NewBB branches to the phi block, add the uncond branch and the phi entry. 
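Several hunks above thread a MemorySSAUpdater through FoldBranchToCommonDest, AddPredecessorToBlock and EraseTerminatorAndDCECond so MemorySSA stays valid while edges are rewired. Here is the MemoryPhi half of that bookkeeping isolated into a standalone helper for clarity (a sketch that assumes MemorySSA has been computed for the function; addMemoryPhiEntry is an invented name):

#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/BasicBlock.h"

using namespace llvm;

// When NewPred starts branching to Succ and its memory state matches
// ExistPred's, the MemoryPhi in Succ (if any) needs a matching incoming
// entry, just like the ordinary PHI nodes do.
static void addMemoryPhiEntry(MemorySSAUpdater *MSSAU, BasicBlock *Succ,
                              BasicBlock *NewPred, BasicBlock *ExistPred) {
  if (!MSSAU)
    return;
  if (MemoryPhi *MPhi = MSSAU->getMemorySSA()->getMemoryAccess(Succ))
    MPhi->addIncoming(MPhi->getIncomingValueForBlock(ExistPred), NewPred);
}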
Builder.SetInsertPoint(NewBB); @@ -4184,24 +4184,28 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { Changed = true; } } else { + Value* Cond = BI->getCondition(); if (BI->getSuccessor(0) == BB) { + Builder.CreateAssumption(Builder.CreateNot(Cond)); Builder.CreateBr(BI->getSuccessor(1)); EraseTerminatorAndDCECond(BI); } else if (BI->getSuccessor(1) == BB) { + Builder.CreateAssumption(Cond); Builder.CreateBr(BI->getSuccessor(0)); EraseTerminatorAndDCECond(BI); Changed = true; } } } else if (auto *SI = dyn_cast<SwitchInst>(TI)) { - for (auto i = SI->case_begin(), e = SI->case_end(); i != e;) { + SwitchInstProfUpdateWrapper SU(*SI); + for (auto i = SU->case_begin(), e = SU->case_end(); i != e;) { if (i->getCaseSuccessor() != BB) { ++i; continue; } - BB->removePredecessor(SI->getParent()); - i = SI->removeCase(i); - e = SI->case_end(); + BB->removePredecessor(SU->getParent()); + i = SU.removeCase(i); + e = SU->case_end(); Changed = true; } } else if (auto *II = dyn_cast<InvokeInst>(TI)) { @@ -4435,33 +4439,20 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, return true; } - SmallVector<uint64_t, 8> Weights; - bool HasWeight = HasBranchWeights(SI); - if (HasWeight) { - GetBranchWeights(SI, Weights); - HasWeight = (Weights.size() == 1 + SI->getNumCases()); - } + if (DeadCases.empty()) + return false; - // Remove dead cases from the switch. + SwitchInstProfUpdateWrapper SIW(*SI); for (ConstantInt *DeadCase : DeadCases) { SwitchInst::CaseIt CaseI = SI->findCaseValue(DeadCase); assert(CaseI != SI->case_default() && "Case was not found. Probably mistake in DeadCases forming."); - if (HasWeight) { - std::swap(Weights[CaseI->getCaseIndex() + 1], Weights.back()); - Weights.pop_back(); - } - // Prune unused values from PHI nodes. CaseI->getCaseSuccessor()->removePredecessor(SI->getParent()); - SI->removeCase(CaseI); - } - if (HasWeight && Weights.size() >= 2) { - SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end()); - setBranchWeights(SI, MDWeights); + SIW.removeCase(CaseI); } - return !DeadCases.empty(); + return true; } /// If BB would be eligible for simplification by @@ -5034,7 +5025,7 @@ SwitchLookupTable::SwitchLookupTable( ArrayType *ArrayTy = ArrayType::get(ValueType, TableSize); Constant *Initializer = ConstantArray::get(ArrayTy, TableContents); - Array = new GlobalVariable(M, ArrayTy, /*constant=*/true, + Array = new GlobalVariable(M, ArrayTy, /*isConstant=*/true, GlobalVariable::PrivateLinkage, Initializer, "switch.table." + FuncName); Array->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); @@ -5091,7 +5082,9 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) { Value *GEPIndices[] = {Builder.getInt32(0), Index}; Value *GEP = Builder.CreateInBoundsGEP(Array->getValueType(), Array, GEPIndices, "switch.gep"); - return Builder.CreateLoad(GEP, "switch.load"); + return Builder.CreateLoad( + cast<ArrayType>(Array->getValueType())->getElementType(), GEP, + "switch.load"); } } llvm_unreachable("Unknown lookup table kind!"); @@ -5425,7 +5418,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // We cached PHINodes in PHIs. To avoid accessing deleted PHINodes later, // do not delete PHINodes here. 
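In the SimplifyUnreachable hunk above, a conditional branch into an unreachable block is not just folded away: the fact it implies about the condition is first recorded with an llvm.assume so later passes can still exploit it. A minimal sketch of that step; recordConditionAndFold is an illustrative helper, the builder is assumed to be positioned at the branch, and erasing the old terminator is left to the caller as in the pass:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

static void recordConditionAndFold(BranchInst *BI, unsigned UnreachableIdx,
                                   IRBuilder<> &Builder) {
  Value *Cond = BI->getCondition();
  if (UnreachableIdx == 0) {
    // The true edge leads to unreachable, so on any real execution Cond is
    // false: remember that, then branch straight to the surviving successor.
    Builder.CreateAssumption(Builder.CreateNot(Cond));
    Builder.CreateBr(BI->getSuccessor(1));
  } else {
    // The false edge is the dead one, so Cond must be true.
    Builder.CreateAssumption(Cond);
    Builder.CreateBr(BI->getSuccessor(0));
  }
  // Erasing BI (and a now-dead Cond) is left to the caller, as the pass does
  // via EraseTerminatorAndDCECond.
}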
SI->getDefaultDest()->removePredecessor(SI->getParent(), - /*DontDeleteUselessPHIs=*/true); + /*KeepOneInputPHIs=*/true); } bool ReturnedEarly = false; @@ -5533,25 +5526,23 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, // Now we have signed numbers that have been shifted so that, given enough // precision, there are no negative values. Since the rest of the transform // is bitwise only, we switch now to an unsigned representation. - uint64_t GCD = 0; - for (auto &V : Values) - GCD = GreatestCommonDivisor64(GCD, (uint64_t)V); - // This transform can be done speculatively because it is so cheap - it results - // in a single rotate operation being inserted. This can only happen if the - // factor extracted is a power of 2. - // FIXME: If the GCD is an odd number we can multiply by the multiplicative - // inverse of GCD and then perform this transform. + // This transform can be done speculatively because it is so cheap - it + // results in a single rotate operation being inserted. // FIXME: It's possible that optimizing a switch on powers of two might also // be beneficial - flag values are often powers of two and we could use a CLZ // as the key function. - if (GCD <= 1 || !isPowerOf2_64(GCD)) - // No common divisor found or too expensive to compute key function. - return false; - unsigned Shift = Log2_64(GCD); + // countTrailingZeros(0) returns 64. As Values is guaranteed to have more than + // one element and LLVM disallows duplicate cases, Shift is guaranteed to be + // less than 64. + unsigned Shift = 64; for (auto &V : Values) - V = (int64_t)((uint64_t)V >> Shift); + Shift = std::min(Shift, countTrailingZeros((uint64_t)V)); + assert(Shift < 64); + if (Shift > 0) + for (auto &V : Values) + V = (int64_t)((uint64_t)V >> Shift); if (!isSwitchDense(Values)) // Transform didn't create a dense switch. @@ -5796,7 +5787,7 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, // branches to us and our successor, fold the comparison into the // predecessor and use logical operations to update the incoming value // for PHI nodes in common successor. - if (FoldBranchToCommonDest(BI, Options.BonusInstThreshold)) + if (FoldBranchToCommonDest(BI, nullptr, Options.BonusInstThreshold)) return requestResimplify(); return false; } @@ -5860,7 +5851,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // If this basic block is ONLY a compare and a branch, and if a predecessor // branches to us and one of our successors, fold the comparison into the // predecessor and use logical operations to pick the right destination. - if (FoldBranchToCommonDest(BI, Options.BonusInstThreshold)) + if (FoldBranchToCommonDest(BI, nullptr, Options.BonusInstThreshold)) return requestResimplify(); // We have a conditional branch to two blocks that are only reachable diff --git a/lib/Transforms/Utils/SimplifyIndVar.cpp b/lib/Transforms/Utils/SimplifyIndVar.cpp index 7faf291e73d9..cbb114f9a47a 100644 --- a/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -1,9 +1,8 @@ //===-- SimplifyIndVar.cpp - Induction variable simplification ------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
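The ReduceSwitchRange hunk above drops the power-of-two GCD computation and instead shifts every (already rebased) case value right by the smallest trailing-zero count among them; because the switch has at least two distinct cases, some value is nonzero and the shift is well defined. A standalone numeric illustration in plain C++, with no LLVM types:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

static unsigned ctz64(uint64_t V) { // trailing zero count, V != 0
  unsigned N = 0;
  while ((V & 1) == 0) { V >>= 1; ++N; }
  return N;
}

int main() {
  // Typical flag values: sparse as-is, dense once the shared trailing zeros
  // are shifted out.
  std::vector<int64_t> Values = {0x10, 0x20, 0x40, 0x80};
  unsigned Shift = 64;
  for (int64_t V : Values)
    Shift = std::min(Shift, ctz64((uint64_t)V));
  for (int64_t &V : Values)
    V = (int64_t)((uint64_t)V >> Shift);
  std::printf("shift=%u ->", Shift);
  for (int64_t V : Values)
    std::printf(" %lld", (long long)V); // shift=4 -> 1 2 4 8
  std::printf("\n");
  return 0;
}

The pass still re-checks isSwitchDense afterwards; the shift only helps flag-like switches become dense enough for a jump table or lookup table.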
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -23,6 +22,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -80,7 +80,8 @@ namespace { bool eliminateIdentitySCEV(Instruction *UseInst, Instruction *IVOperand); bool replaceIVUserWithLoopInvariant(Instruction *UseInst); - bool eliminateOverflowIntrinsic(CallInst *CI); + bool eliminateOverflowIntrinsic(WithOverflowInst *WO); + bool eliminateSaturatingIntrinsic(SaturatingInst *SI); bool eliminateTrunc(TruncInst *TI); bool eliminateIVUser(Instruction *UseInst, Instruction *IVOperand); bool makeIVComparisonInvariant(ICmpInst *ICmp, Value *IVOperand); @@ -401,61 +402,29 @@ void SimplifyIndvar::simplifyIVRemainder(BinaryOperator *Rem, Value *IVOperand, replaceSRemWithURem(Rem); } -bool SimplifyIndvar::eliminateOverflowIntrinsic(CallInst *CI) { - auto *F = CI->getCalledFunction(); - if (!F) - return false; - - typedef const SCEV *(ScalarEvolution::*OperationFunctionTy)( - const SCEV *, const SCEV *, SCEV::NoWrapFlags, unsigned); - typedef const SCEV *(ScalarEvolution::*ExtensionFunctionTy)( - const SCEV *, Type *, unsigned); - - OperationFunctionTy Operation; - ExtensionFunctionTy Extension; - - Instruction::BinaryOps RawOp; - - // We always have exactly one of nsw or nuw. If NoSignedOverflow is false, we - // have nuw. - bool NoSignedOverflow; - - switch (F->getIntrinsicID()) { +static bool willNotOverflow(ScalarEvolution *SE, Instruction::BinaryOps BinOp, + bool Signed, const SCEV *LHS, const SCEV *RHS) { + const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *, + SCEV::NoWrapFlags, unsigned); + switch (BinOp) { default: - return false; - - case Intrinsic::sadd_with_overflow: - Operation = &ScalarEvolution::getAddExpr; - Extension = &ScalarEvolution::getSignExtendExpr; - RawOp = Instruction::Add; - NoSignedOverflow = true; - break; - - case Intrinsic::uadd_with_overflow: + llvm_unreachable("Unsupported binary op"); + case Instruction::Add: Operation = &ScalarEvolution::getAddExpr; - Extension = &ScalarEvolution::getZeroExtendExpr; - RawOp = Instruction::Add; - NoSignedOverflow = false; break; - - case Intrinsic::ssub_with_overflow: + case Instruction::Sub: Operation = &ScalarEvolution::getMinusSCEV; - Extension = &ScalarEvolution::getSignExtendExpr; - RawOp = Instruction::Sub; - NoSignedOverflow = true; break; - - case Intrinsic::usub_with_overflow: - Operation = &ScalarEvolution::getMinusSCEV; - Extension = &ScalarEvolution::getZeroExtendExpr; - RawOp = Instruction::Sub; - NoSignedOverflow = false; + case Instruction::Mul: + Operation = &ScalarEvolution::getMulExpr; break; } - const SCEV *LHS = SE->getSCEV(CI->getArgOperand(0)); - const SCEV *RHS = SE->getSCEV(CI->getArgOperand(1)); + const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) = + Signed ? 
&ScalarEvolution::getSignExtendExpr + : &ScalarEvolution::getZeroExtendExpr; + // Check ext(LHS op RHS) == ext(LHS) op ext(RHS) auto *NarrowTy = cast<IntegerType>(LHS->getType()); auto *WideTy = IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2); @@ -466,27 +435,32 @@ bool SimplifyIndvar::eliminateOverflowIntrinsic(CallInst *CI) { const SCEV *B = (SE->*Operation)((SE->*Extension)(LHS, WideTy, 0), (SE->*Extension)(RHS, WideTy, 0), SCEV::FlagAnyWrap, 0); + return A == B; +} - if (A != B) +bool SimplifyIndvar::eliminateOverflowIntrinsic(WithOverflowInst *WO) { + const SCEV *LHS = SE->getSCEV(WO->getLHS()); + const SCEV *RHS = SE->getSCEV(WO->getRHS()); + if (!willNotOverflow(SE, WO->getBinaryOp(), WO->isSigned(), LHS, RHS)) return false; // Proved no overflow, nuke the overflow check and, if possible, the overflow // intrinsic as well. BinaryOperator *NewResult = BinaryOperator::Create( - RawOp, CI->getArgOperand(0), CI->getArgOperand(1), "", CI); + WO->getBinaryOp(), WO->getLHS(), WO->getRHS(), "", WO); - if (NoSignedOverflow) + if (WO->isSigned()) NewResult->setHasNoSignedWrap(true); else NewResult->setHasNoUnsignedWrap(true); SmallVector<ExtractValueInst *, 4> ToDelete; - for (auto *U : CI->users()) { + for (auto *U : WO->users()) { if (auto *EVI = dyn_cast<ExtractValueInst>(U)) { if (EVI->getIndices()[0] == 1) - EVI->replaceAllUsesWith(ConstantInt::getFalse(CI->getContext())); + EVI->replaceAllUsesWith(ConstantInt::getFalse(WO->getContext())); else { assert(EVI->getIndices()[0] == 0 && "Only two possibilities!"); EVI->replaceAllUsesWith(NewResult); @@ -498,9 +472,28 @@ bool SimplifyIndvar::eliminateOverflowIntrinsic(CallInst *CI) { for (auto *EVI : ToDelete) EVI->eraseFromParent(); - if (CI->use_empty()) - CI->eraseFromParent(); + if (WO->use_empty()) + WO->eraseFromParent(); + + return true; +} + +bool SimplifyIndvar::eliminateSaturatingIntrinsic(SaturatingInst *SI) { + const SCEV *LHS = SE->getSCEV(SI->getLHS()); + const SCEV *RHS = SE->getSCEV(SI->getRHS()); + if (!willNotOverflow(SE, SI->getBinaryOp(), SI->isSigned(), LHS, RHS)) + return false; + + BinaryOperator *BO = BinaryOperator::Create( + SI->getBinaryOp(), SI->getLHS(), SI->getRHS(), SI->getName(), SI); + if (SI->isSigned()) + BO->setHasNoSignedWrap(); + else + BO->setHasNoUnsignedWrap(); + SI->replaceAllUsesWith(BO); + DeadInsts.emplace_back(SI); + Changed = true; return true; } @@ -548,20 +541,19 @@ bool SimplifyIndvar::eliminateTrunc(TruncInst *TI) { if (isa<Instruction>(U) && !DT->isReachableFromEntry(cast<Instruction>(U)->getParent())) continue; - if (ICmpInst *ICI = dyn_cast<ICmpInst>(U)) { - if (ICI->getOperand(0) == TI && L->isLoopInvariant(ICI->getOperand(1))) { - assert(L->contains(ICI->getParent()) && "LCSSA form broken?"); - // If we cannot get rid of trunc, bail. - if (ICI->isSigned() && !DoesSExtCollapse) - return false; - if (ICI->isUnsigned() && !DoesZExtCollapse) - return false; - // For equality, either signed or unsigned works. - ICmpUsers.push_back(ICI); - } else - return false; - } else + ICmpInst *ICI = dyn_cast<ICmpInst>(U); + if (!ICI) return false; + assert(L->contains(ICI->getParent()) && "LCSSA form broken?"); + if (!(ICI->getOperand(0) == TI && L->isLoopInvariant(ICI->getOperand(1))) && + !(ICI->getOperand(1) == TI && L->isLoopInvariant(ICI->getOperand(0)))) return false; + // If we cannot get rid of trunc, bail. 
+ if (ICI->isSigned() && !DoesSExtCollapse) + return false; + if (ICI->isUnsigned() && !DoesZExtCollapse) + return false; + // For equality, either signed or unsigned works. + ICmpUsers.push_back(ICI); } auto CanUseZExt = [&](ICmpInst *ICI) { @@ -584,7 +576,8 @@ bool SimplifyIndvar::eliminateTrunc(TruncInst *TI) { }; // Replace all comparisons against trunc with comparisons against IV. for (auto *ICI : ICmpUsers) { - auto *Op1 = ICI->getOperand(1); + bool IsSwapped = L->isLoopInvariant(ICI->getOperand(0)); + auto *Op1 = IsSwapped ? ICI->getOperand(0) : ICI->getOperand(1); Instruction *Ext = nullptr; // For signed/unsigned predicate, replace the old comparison with comparison // of immediate IV against sext/zext of the invariant argument. If we can @@ -593,6 +586,7 @@ bool SimplifyIndvar::eliminateTrunc(TruncInst *TI) { // TODO: If we see a signed comparison which can be turned into unsigned, // we can do it here for canonicalization purposes. ICmpInst::Predicate Pred = ICI->getPredicate(); + if (IsSwapped) Pred = ICmpInst::getSwappedPredicate(Pred); if (CanUseZExt(ICI)) { assert(DoesZExtCollapse && "Unprofitable zext?"); Ext = new ZExtInst(Op1, IVTy, "zext", ICI); @@ -636,8 +630,12 @@ bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, return eliminateSDiv(Bin); } - if (auto *CI = dyn_cast<CallInst>(UseInst)) - if (eliminateOverflowIntrinsic(CI)) + if (auto *WO = dyn_cast<WithOverflowInst>(UseInst)) + if (eliminateOverflowIntrinsic(WO)) + return true; + + if (auto *SI = dyn_cast<SaturatingInst>(UseInst)) + if (eliminateSaturatingIntrinsic(SI)) return true; if (auto *TI = dyn_cast<TruncInst>(UseInst)) @@ -730,59 +728,31 @@ bool SimplifyIndvar::eliminateIdentitySCEV(Instruction *UseInst, /// unsigned-overflow. Returns true if anything changed, false otherwise. bool SimplifyIndvar::strengthenOverflowingOperation(BinaryOperator *BO, Value *IVOperand) { - // Fastpath: we don't have any work to do if `BO` is `nuw` and `nsw`. 
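willNotOverflow above, which eliminateOverflowIntrinsic, eliminateSaturatingIntrinsic and the rewritten strengthenOverflowingOperation below all share, proves no-wrap by comparing "extend both operands to twice the width, then operate" with "operate narrow, then extend": the two SCEV expressions agree exactly when the narrow operation cannot wrap. The same identity checked with concrete 8-bit values (illustrative only; the pass reasons about SCEVs, not runtime values):

#include <cstdint>
#include <cstdio>

// True iff the 8-bit unsigned add A + B does not wrap, tested by comparing
// the narrow result extended to 16 bits against the 16-bit add of the
// extended operands.
static bool noUnsignedWrapAdd8(uint8_t A, uint8_t B) {
  uint16_t OpThenExt = (uint16_t)(uint8_t)(A + B); // narrow add, then zext
  uint16_t ExtThenOp = (uint16_t)A + (uint16_t)B;  // zext, then wide add
  return OpThenExt == ExtThenOp;
}

int main() {
  std::printf("%d\n", noUnsignedWrapAdd8(100, 27));  // 1: 127 fits in 8 bits
  std::printf("%d\n", noUnsignedWrapAdd8(200, 100)); // 0: 300 wraps to 44
  return 0;
}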
if (BO->hasNoUnsignedWrap() && BO->hasNoSignedWrap()) return false; - const SCEV *(ScalarEvolution::*GetExprForBO)(const SCEV *, const SCEV *, - SCEV::NoWrapFlags, unsigned); - switch (BO->getOpcode()) { - default: + if (BO->getOpcode() != Instruction::Add && + BO->getOpcode() != Instruction::Sub && + BO->getOpcode() != Instruction::Mul) return false; - case Instruction::Add: - GetExprForBO = &ScalarEvolution::getAddExpr; - break; - - case Instruction::Sub: - GetExprForBO = &ScalarEvolution::getMinusSCEV; - break; - - case Instruction::Mul: - GetExprForBO = &ScalarEvolution::getMulExpr; - break; - } - - unsigned BitWidth = cast<IntegerType>(BO->getType())->getBitWidth(); - Type *WideTy = IntegerType::get(BO->getContext(), BitWidth * 2); const SCEV *LHS = SE->getSCEV(BO->getOperand(0)); const SCEV *RHS = SE->getSCEV(BO->getOperand(1)); - bool Changed = false; - if (!BO->hasNoUnsignedWrap()) { - const SCEV *ExtendAfterOp = SE->getZeroExtendExpr(SE->getSCEV(BO), WideTy); - const SCEV *OpAfterExtend = (SE->*GetExprForBO)( - SE->getZeroExtendExpr(LHS, WideTy), SE->getZeroExtendExpr(RHS, WideTy), - SCEV::FlagAnyWrap, 0u); - if (ExtendAfterOp == OpAfterExtend) { - BO->setHasNoUnsignedWrap(); - SE->forgetValue(BO); - Changed = true; - } + if (!BO->hasNoUnsignedWrap() && + willNotOverflow(SE, BO->getOpcode(), /* Signed */ false, LHS, RHS)) { + BO->setHasNoUnsignedWrap(); + SE->forgetValue(BO); + Changed = true; } - if (!BO->hasNoSignedWrap()) { - const SCEV *ExtendAfterOp = SE->getSignExtendExpr(SE->getSCEV(BO), WideTy); - const SCEV *OpAfterExtend = (SE->*GetExprForBO)( - SE->getSignExtendExpr(LHS, WideTy), SE->getSignExtendExpr(RHS, WideTy), - SCEV::FlagAnyWrap, 0u); - if (ExtendAfterOp == OpAfterExtend) { - BO->setHasNoSignedWrap(); - SE->forgetValue(BO); - Changed = true; - } + if (!BO->hasNoSignedWrap() && + willNotOverflow(SE, BO->getOpcode(), /* Signed */ true, LHS, RHS)) { + BO->setHasNoSignedWrap(); + SE->forgetValue(BO); + Changed = true; } return Changed; diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 1bb26caa2af2..e0def81d5eee 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1,9 +1,8 @@ //===------ SimplifyLibCalls.cpp - Library calls simplifier ---------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -17,8 +16,10 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" @@ -35,6 +36,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" +#include "llvm/Transforms/Utils/SizeOpts.h" using namespace llvm; using namespace PatternMatch; @@ -105,6 +107,12 @@ static bool callHasFloatingPointArgument(const CallInst *CI) { }); } +static bool callHasFP128Argument(const CallInst *CI) { + return any_of(CI->operands(), [](const Use &OI) { + return OI->getType()->isFP128Ty(); + }); +} + static Value *convertStrToNumber(CallInst *CI, StringRef &Str, int64_t Base) { if (Base < 2 || Base > 36) // handle special zero base @@ -334,11 +342,12 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { return ConstantInt::get(CI->getType(), Str1.compare(Str2)); if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x - return B.CreateNeg( - B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), CI->getType())); + return B.CreateNeg(B.CreateZExt( + B.CreateLoad(B.getInt8Ty(), Str2P, "strcmpload"), CI->getType())); if (HasStr2 && Str2.empty()) // strcmp(x,"") -> *x - return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + return B.CreateZExt(B.CreateLoad(B.getInt8Ty(), Str1P, "strcmpload"), + CI->getType()); // strcmp(P, "x") -> memcmp(P, "x", 2) uint64_t Len1 = GetStringLength(Str1P); @@ -398,11 +407,12 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { } if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x - return B.CreateNeg( - B.CreateZExt(B.CreateLoad(Str2P, "strcmpload"), CI->getType())); + return B.CreateNeg(B.CreateZExt( + B.CreateLoad(B.getInt8Ty(), Str2P, "strcmpload"), CI->getType())); if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x - return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + return B.CreateZExt(B.CreateLoad(B.getInt8Ty(), Str1P, "strcmpload"), + CI->getType()); uint64_t Len1 = GetStringLength(Str1P); uint64_t Len2 = GetStringLength(Str2P); @@ -591,7 +601,8 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilder<> &B, // strlen(x) != 0 --> *x != 0 // strlen(x) == 0 --> *x == 0 if (isOnlyUsedInZeroEqualityComparison(CI)) - return B.CreateZExt(B.CreateLoad(Src, "strlenfirst"), CI->getType()); + return B.CreateZExt(B.CreateLoad(B.getIntNTy(CharSize), Src, "strlenfirst"), + CI->getType()); return nullptr; } @@ -735,7 +746,8 @@ Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilder<> &B) { // strstr("abcd", "bc") -> gep((char*)"abcd", 1) Value *Result = castToCStr(CI->getArgOperand(0), B); - Result = B.CreateConstInBoundsGEP1_64(Result, Offset, "strstr"); + Result = + B.CreateConstInBoundsGEP1_64(B.getInt8Ty(), Result, Offset, "strstr"); return B.CreateBitCast(Result, CI->getType()); } @@ -773,7 +785,8 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilder<> &B) { // It would be really nice to reuse switch lowering here but we can't change // the CFG at this point. 
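The string-call folds above now spell out the loaded type (B.getInt8Ty(), B.getIntNTy(CharSize)) rather than letting CreateLoad derive it from the pointer operand, in line with the move toward explicitly typed memory accesses. A minimal sketch of the pattern for the strlen(x)-only-compared-against-zero fold; emitFirstCharTest is an invented name, CharSize is assumed to be 8, and the legality checks done by the pass are omitted:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// CI is the strlen call whose result only feeds ==0 / !=0 comparisons and
// Src is its i8* argument; its value can be replaced by a zero-extended load
// of the first character.
static Value *emitFirstCharTest(CallInst *CI, Value *Src, IRBuilder<> &B) {
  // The element type is now written explicitly instead of being taken from
  // Src's pointer type.
  Value *FirstChar = B.CreateLoad(B.getInt8Ty(), Src, "strlenfirst");
  return B.CreateZExt(FirstChar, CI->getType());
}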
// - // memchr("\r\n", C, 2) != nullptr -> (C & ((1 << '\r') | (1 << '\n'))) != 0 + // memchr("\r\n", C, 2) != nullptr -> (1 << C & ((1 << '\r') | (1 << '\n'))) + // != 0 // after bounds check. if (!CharC && !Str.empty() && isOnlyUsedInZeroEqualityComparison(CI)) { unsigned char Max = @@ -828,27 +841,20 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilder<> &B) { return B.CreateGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "memchr"); } -Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { - Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); - - if (LHS == RHS) // memcmp(s,s,x) -> 0 - return Constant::getNullValue(CI->getType()); - - // Make sure we have a constant length. - ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); - if (!LenC) - return nullptr; - - uint64_t Len = LenC->getZExtValue(); +static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS, + uint64_t Len, IRBuilder<> &B, + const DataLayout &DL) { if (Len == 0) // memcmp(s1,s2,0) -> 0 return Constant::getNullValue(CI->getType()); // memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS if (Len == 1) { - Value *LHSV = B.CreateZExt(B.CreateLoad(castToCStr(LHS, B), "lhsc"), - CI->getType(), "lhsv"); - Value *RHSV = B.CreateZExt(B.CreateLoad(castToCStr(RHS, B), "rhsc"), - CI->getType(), "rhsv"); + Value *LHSV = + B.CreateZExt(B.CreateLoad(B.getInt8Ty(), castToCStr(LHS, B), "lhsc"), + CI->getType(), "lhsv"); + Value *RHSV = + B.CreateZExt(B.CreateLoad(B.getInt8Ty(), castToCStr(RHS, B), "rhsc"), + CI->getType(), "rhsv"); return B.CreateSub(LHSV, RHSV, "chardiff"); } @@ -878,12 +884,12 @@ Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { if (!LHSV) { Type *LHSPtrTy = IntType->getPointerTo(LHS->getType()->getPointerAddressSpace()); - LHSV = B.CreateLoad(B.CreateBitCast(LHS, LHSPtrTy), "lhsv"); + LHSV = B.CreateLoad(IntType, B.CreateBitCast(LHS, LHSPtrTy), "lhsv"); } if (!RHSV) { Type *RHSPtrTy = IntType->getPointerTo(RHS->getType()->getPointerAddressSpace()); - RHSV = B.CreateLoad(B.CreateBitCast(RHS, RHSPtrTy), "rhsv"); + RHSV = B.CreateLoad(IntType, B.CreateBitCast(RHS, RHSPtrTy), "rhsv"); } return B.CreateZExt(B.CreateICmpNE(LHSV, RHSV), CI->getType(), "memcmp"); } @@ -907,10 +913,48 @@ Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { Ret = 1; return ConstantInt::get(CI->getType(), Ret); } + return nullptr; +} + +// Most simplifications for memcmp also apply to bcmp. +Value *LibCallSimplifier::optimizeMemCmpBCmpCommon(CallInst *CI, + IRBuilder<> &B) { + Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); + Value *Size = CI->getArgOperand(2); + + if (LHS == RHS) // memcmp(s,s,x) -> 0 + return Constant::getNullValue(CI->getType()); + + // Handle constant lengths. + if (ConstantInt *LenC = dyn_cast<ConstantInt>(Size)) + if (Value *Res = optimizeMemCmpConstantSize(CI, LHS, RHS, + LenC->getZExtValue(), B, DL)) + return Res; + + return nullptr; +} + +Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { + if (Value *V = optimizeMemCmpBCmpCommon(CI, B)) + return V; + + // memcmp(x, y, Len) == 0 -> bcmp(x, y, Len) == 0 + // `bcmp` can be more efficient than memcmp because it only has to know that + // there is a difference, not where it is. 
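The memcmp handling above is split into a shared memcmp/bcmp path plus, continuing just below, a final fold: when the result only ever feeds ==0 / !=0 comparisons and the target provides bcmp, the call is rewritten with emitBCmp, since bcmp only has to report whether the buffers differ, not how they are ordered. A sketch of that last step in isolation, assuming emitBCmp is declared in BuildLibCalls.h with the signature used by the call above and that isOnlyUsedInZeroEqualityComparison comes from ValueTracking.h; foldMemCmpToBCmp is an invented name:

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"

using namespace llvm;

// CI is the memcmp call; DL and TLI come from the caller, as in the
// simplifier above.  Returns the replacement value or nullptr.
static Value *foldMemCmpToBCmp(CallInst *CI, IRBuilder<> &B,
                               const DataLayout &DL,
                               const TargetLibraryInfo *TLI) {
  if (!isOnlyUsedInZeroEqualityComparison(CI) || !TLI->has(LibFunc_bcmp))
    return nullptr;
  return emitBCmp(CI->getArgOperand(0), CI->getArgOperand(1),
                  CI->getArgOperand(2), B, DL, TLI);
}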
+ if (isOnlyUsedInZeroEqualityComparison(CI) && TLI->has(LibFunc_bcmp)) { + Value *LHS = CI->getArgOperand(0); + Value *RHS = CI->getArgOperand(1); + Value *Size = CI->getArgOperand(2); + return emitBCmp(LHS, RHS, Size, B, DL, TLI); + } return nullptr; } +Value *LibCallSimplifier::optimizeBCmp(CallInst *CI, IRBuilder<> &B) { + return optimizeMemCmpBCmpCommon(CI, B); +} + Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B) { // memcpy(x, y, n) -> llvm.memcpy(align 1 x, align 1 y, n) B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, @@ -1031,7 +1075,8 @@ static Value *valueHasFloatPrecision(Value *Val) { /// Shrink double -> float functions. static Value *optimizeDoubleFP(CallInst *CI, IRBuilder<> &B, bool isBinary, bool isPrecise = false) { - if (!CI->getType()->isDoubleTy()) + Function *CalleeFn = CI->getCalledFunction(); + if (!CI->getType()->isDoubleTy() || !CalleeFn) return nullptr; // If not all the uses of the function are converted to float, then bail out. @@ -1051,15 +1096,16 @@ static Value *optimizeDoubleFP(CallInst *CI, IRBuilder<> &B, if (!V[0] || (isBinary && !V[1])) return nullptr; + StringRef CalleeNm = CalleeFn->getName(); + AttributeList CalleeAt = CalleeFn->getAttributes(); + bool CalleeIn = CalleeFn->isIntrinsic(); + // If call isn't an intrinsic, check that it isn't within a function with the // same name as the float version of this call, otherwise the result is an // infinite loop. For example, from MinGW-w64: // // float expf(float val) { return (float) exp((double) val); } - Function *CalleeFn = CI->getCalledFunction(); - StringRef CalleeNm = CalleeFn->getName(); - AttributeList CalleeAt = CalleeFn->getAttributes(); - if (CalleeFn && !CalleeFn->isIntrinsic()) { + if (!CalleeIn) { const Function *Fn = CI->getFunction(); StringRef FnName = Fn->getName(); if (FnName.back() == 'f' && @@ -1074,7 +1120,7 @@ static Value *optimizeDoubleFP(CallInst *CI, IRBuilder<> &B, // g((double) float) -> (double) gf(float) Value *R; - if (CalleeFn->isIntrinsic()) { + if (CalleeIn) { Module *M = CI->getModule(); Intrinsic::ID IID = CalleeFn->getIntrinsicID(); Function *Fn = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); @@ -1132,10 +1178,10 @@ static Value *optimizeTrigReflections(CallInst *Call, LibFunc Func, IRBuilder<> &B) { if (!isa<FPMathOperator>(Call)) return nullptr; - + IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(Call->getFastMathFlags()); - + // TODO: Can this be shared to also handle LLVM intrinsics? Value *X; switch (Func) { @@ -1189,7 +1235,8 @@ static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) { } /// Use exp{,2}(x * y) for pow(exp{,2}(x), y); -/// exp2(n * x) for pow(2.0 ** n, x); exp10(x) for pow(10.0, x). +/// exp2(n * x) for pow(2.0 ** n, x); exp10(x) for pow(10.0, x); +/// exp2(log2(n) * x) for pow(n, x). Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); AttributeList Attrs = Pow->getCalledFunction()->getAttributes(); @@ -1276,12 +1323,12 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { APFloat BaseR = APFloat(1.0); BaseR.convert(BaseF->getSemantics(), APFloat::rmTowardZero, &Ignored); BaseR = BaseR / *BaseF; - bool IsInteger = BaseF->isInteger(), - IsReciprocal = BaseR.isInteger(); + bool IsInteger = BaseF->isInteger(), IsReciprocal = BaseR.isInteger(); const APFloat *NF = IsReciprocal ? 
&BaseR : BaseF; APSInt NI(64, false); if ((IsInteger || IsReciprocal) && - !NF->convertToInteger(NI, APFloat::rmTowardZero, &Ignored) && + NF->convertToInteger(NI, APFloat::rmTowardZero, &Ignored) == + APFloat::opOK && NI > 1 && NI.isPowerOf2()) { double N = NI.logBase2() * (IsReciprocal ? -1.0 : 1.0); Value *FMul = B.CreateFMul(Expo, ConstantFP::get(Ty, N), "mul"); @@ -1301,6 +1348,28 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { return emitUnaryFloatFnCall(Expo, TLI, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l, B, Attrs); + // pow(n, x) -> exp2(log2(n) * x) + if (Pow->hasOneUse() && Pow->hasApproxFunc() && Pow->hasNoNaNs() && + Pow->hasNoInfs() && BaseF->isNormal() && !BaseF->isNegative()) { + Value *Log = nullptr; + if (Ty->isFloatTy()) + Log = ConstantFP::get(Ty, std::log2(BaseF->convertToFloat())); + else if (Ty->isDoubleTy()) + Log = ConstantFP::get(Ty, std::log2(BaseF->convertToDouble())); + + if (Log) { + Value *FMul = B.CreateFMul(Log, Expo, "mul"); + if (Pow->doesNotAccessMemory()) { + return B.CreateCall(Intrinsic::getDeclaration(Mod, Intrinsic::exp2, Ty), + FMul, "exp2"); + } else { + if (hasUnaryFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, + LibFunc_exp2l)) + return emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f, + LibFunc_exp2l, B, Attrs); + } + } + } return nullptr; } @@ -1364,12 +1433,22 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { return Sqrt; } +static Value *createPowWithIntegerExponent(Value *Base, Value *Expo, Module *M, + IRBuilder<> &B) { + Value *Args[] = {Base, Expo}; + Function *F = Intrinsic::getDeclaration(M, Intrinsic::powi, Base->getType()); + return B.CreateCall(F, Args); +} + Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { - Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); + Value *Base = Pow->getArgOperand(0); + Value *Expo = Pow->getArgOperand(1); Function *Callee = Pow->getCalledFunction(); StringRef Name = Callee->getName(); Type *Ty = Pow->getType(); + Module *M = Pow->getModule(); Value *Shrunk = nullptr; + bool AllowApprox = Pow->hasApproxFunc(); bool Ignored; // Bail out if simplifying libcalls to pow() is disabled. @@ -1382,8 +1461,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { // Shrink pow() to powf() if the arguments are single precision, // unless the result is expected to be double precision. - if (UnsafeFPShrink && - Name == TLI->getName(LibFunc_pow) && hasFloatVersion(Name)) + if (UnsafeFPShrink && Name == TLI->getName(LibFunc_pow) && + hasFloatVersion(Name)) Shrunk = optimizeBinaryDoubleFP(Pow, B, true); // Evaluate special cases related to the base. @@ -1403,7 +1482,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { // pow(x, 0.0) -> 1.0 if (match(Expo, m_SpecificFP(0.0))) - return ConstantFP::get(Ty, 1.0); + return ConstantFP::get(Ty, 1.0); // pow(x, 1.0) -> x if (match(Expo, m_FPOne())) @@ -1418,7 +1497,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { // pow(x, n) -> x * x * x * ... const APFloat *ExpoF; - if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) { + if (AllowApprox && match(Expo, m_APFloat(ExpoF))) { // We limit to a max of 7 multiplications, thus the maximum exponent is 32. // If the exponent is an integer+0.5 we generate a call to sqrt and an // additional fmul. 
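A numerical sketch of the new pow(n, x) -> exp2(log2(n) * x) rewrite for a positive, normal, constant base (illustrative only, not from the diff). The identity only holds approximately in floating point, which is why the hunk above guards it with approx-func/nnan/ninf fast-math flags:

#include <cmath>
#include <cassert>
#include <cstdio>

int main() {
  const double n = 6.5, x = 3.25;                     // arbitrary positive normal base
  double lhs = std::pow(n, x);
  double rhs = std::exp2(std::log2(n) * x);
  std::printf("pow = %.17g  exp2(log2(n)*x) = %.17g\n", lhs, rhs);
  assert(std::fabs(lhs - rhs) <= 1e-9 * std::fabs(lhs));  // close, not bit-exact
  return 0;
}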
@@ -1442,9 +1521,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { if (!Expo2.isInteger()) return nullptr; - Sqrt = - getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(), - Pow->doesNotAccessMemory(), Pow->getModule(), B, TLI); + Sqrt = getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(), + Pow->doesNotAccessMemory(), M, B, TLI); } // We will memoize intermediate products of the Addition Chain. @@ -1467,6 +1545,29 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { return FMul; } + + APSInt IntExpo(32, /*isUnsigned=*/false); + // powf(x, n) -> powi(x, n) if n is a constant signed integer value + if (ExpoF->isInteger() && + ExpoF->convertToInteger(IntExpo, APFloat::rmTowardZero, &Ignored) == + APFloat::opOK) { + return createPowWithIntegerExponent( + Base, ConstantInt::get(B.getInt32Ty(), IntExpo), M, B); + } + } + + // powf(x, itofp(y)) -> powi(x, y) + if (AllowApprox && (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo))) { + Value *IntExpo = cast<Instruction>(Expo)->getOperand(0); + Value *NewExpo = nullptr; + unsigned BitWidth = IntExpo->getType()->getPrimitiveSizeInBits(); + if (isa<SIToFPInst>(Expo) && BitWidth == 32) + NewExpo = IntExpo; + else if (BitWidth < 32) + NewExpo = isa<SIToFPInst>(Expo) ? B.CreateSExt(IntExpo, B.getInt32Ty()) + : B.CreateZExt(IntExpo, B.getInt32Ty()); + if (NewExpo) + return createPowWithIntegerExponent(Base, NewExpo, M, B); } return Shrunk; @@ -1504,9 +1605,8 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { One = ConstantExpr::getFPExtend(One, Op->getType()); Module *M = CI->getModule(); - Value *NewCallee = - M->getOrInsertFunction(TLI->getName(LdExp), Op->getType(), - Op->getType(), B.getInt32Ty()); + FunctionCallee NewCallee = M->getOrInsertFunction( + TLI->getName(LdExp), Op->getType(), Op->getType(), B.getInt32Ty()); CallInst *CI = B.CreateCall(NewCallee, {One, LdExpArg}); if (const Function *F = dyn_cast<Function>(Callee->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -1518,40 +1618,30 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); // If we can shrink the call to a float function rather than a double // function, do that first. + Function *Callee = CI->getCalledFunction(); StringRef Name = Callee->getName(); if ((Name == "fmin" || Name == "fmax") && hasFloatVersion(Name)) if (Value *Ret = optimizeBinaryDoubleFP(CI, B)) return Ret; + // The LLVM intrinsics minnum/maxnum correspond to fmin/fmax. Canonicalize to + // the intrinsics for improved optimization (for example, vectorization). + // No-signed-zeros is implied by the definitions of fmax/fmin themselves. + // From the C standard draft WG14/N1256: + // "Ideally, fmax would be sensitive to the sign of zero, for example + // fmax(-0.0, +0.0) would return +0; however, implementation in software + // might be impractical." IRBuilder<>::FastMathFlagGuard Guard(B); - FastMathFlags FMF; - if (CI->isFast()) { - // If the call is 'fast', then anything we create here will also be 'fast'. - FMF.setFast(); - } else { - // At a minimum, no-nans-fp-math must be true. - if (!CI->hasNoNaNs()) - return nullptr; - // No-signed-zeros is implied by the definitions of fmax/fmin themselves: - // "Ideally, fmax would be sensitive to the sign of zero, for example - // fmax(-0. 0, +0. 0) would return +0; however, implementation in software - // might be impractical." 
- FMF.setNoSignedZeros(); - FMF.setNoNaNs(); - } + FastMathFlags FMF = CI->getFastMathFlags(); + FMF.setNoSignedZeros(); B.setFastMathFlags(FMF); - // We have a relaxed floating-point environment. We can ignore NaN-handling - // and transform to a compare and select. We do not have to consider errno or - // exceptions, because fmin/fmax do not have those. - Value *Op0 = CI->getArgOperand(0); - Value *Op1 = CI->getArgOperand(1); - Value *Cmp = Callee->getName().startswith("fmin") ? - B.CreateFCmpOLT(Op0, Op1) : B.CreateFCmpOGT(Op0, Op1); - return B.CreateSelect(Cmp, Op0, Op1); + Intrinsic::ID IID = Callee->getName().startswith("fmin") ? Intrinsic::minnum + : Intrinsic::maxnum; + Function *F = Intrinsic::getDeclaration(CI->getModule(), IID, CI->getType()); + return B.CreateCall(F, { CI->getArgOperand(0), CI->getArgOperand(1) }); } Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { @@ -1654,13 +1744,13 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) { // replace it with the fabs of that factor. Module *M = Callee->getParent(); Type *ArgType = I->getType(); - Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType); + Function *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType); Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs"); if (OtherOp) { // If we found a non-repeated factor, we still need to get its square // root. We then multiply that by the value that was simplified out // of the square root calculation. - Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType); + Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType); Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt"); return B.CreateFMul(FabsCall, SqrtCall); } @@ -1728,8 +1818,8 @@ static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg, } Module *M = OrigCallee->getParent(); - Value *Callee = M->getOrInsertFunction(Name, OrigCallee->getAttributes(), - ResTy, ArgTy); + FunctionCallee Callee = + M->getOrInsertFunction(Name, OrigCallee->getAttributes(), ResTy, ArgTy); if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { // If the argument is an instruction, it must dominate all uses so put our @@ -1840,8 +1930,8 @@ Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilder<> &B) { // ffs(x) -> x != 0 ? (i32)llvm.cttz(x)+1 : 0 Value *Op = CI->getArgOperand(0); Type *ArgType = Op->getType(); - Value *F = Intrinsic::getDeclaration(CI->getCalledFunction()->getParent(), - Intrinsic::cttz, ArgType); + Function *F = Intrinsic::getDeclaration(CI->getCalledFunction()->getParent(), + Intrinsic::cttz, ArgType); Value *V = B.CreateCall(F, {Op, B.getTrue()}, "cttz"); V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1)); V = B.CreateIntCast(V, B.getInt32Ty(), false); @@ -1854,8 +1944,8 @@ Value *LibCallSimplifier::optimizeFls(CallInst *CI, IRBuilder<> &B) { // fls(x) -> (i32)(sizeInBits(x) - llvm.ctlz(x, false)) Value *Op = CI->getArgOperand(0); Type *ArgType = Op->getType(); - Value *F = Intrinsic::getDeclaration(CI->getCalledFunction()->getParent(), - Intrinsic::ctlz, ArgType); + Function *F = Intrinsic::getDeclaration(CI->getCalledFunction()->getParent(), + Intrinsic::ctlz, ArgType); Value *V = B.CreateCall(F, {Op, B.getFalse()}, "ctlz"); V = B.CreateSub(ConstantInt::get(V->getType(), ArgType->getIntegerBitWidth()), V); @@ -2026,13 +2116,27 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilder<> &B) { // arguments. 
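A standalone sketch of the ffs(x) -> (x != 0 ? cttz(x) + 1 : 0) rewrite shown above, using C++20 std::countr_zero as a stand-in for llvm.cttz (illustrative only, not from the diff; assumes a C++20 compiler):

#include <bit>
#include <cassert>

static int ffs_via_cttz(unsigned x) {
  // Guarding on x != 0 matches the select the pass emits around the cttz call.
  return x ? std::countr_zero(x) + 1 : 0;
}

int main() {
  assert(ffs_via_cttz(0u) == 0);
  assert(ffs_via_cttz(1u) == 1);
  assert(ffs_via_cttz(0x80u) == 8);
  assert(ffs_via_cttz(12u) == 3);   // 12 = 0b1100, lowest set bit is bit 2 -> ffs of 3
  return 0;
}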
if (TLI->has(LibFunc_iprintf) && !callHasFloatingPointArgument(CI)) { Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *IPrintFFn = + FunctionCallee IPrintFFn = M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(IPrintFFn); B.Insert(New); return New; } + + // printf(format, ...) -> __small_printf(format, ...) if no 128-bit floating point + // arguments. + if (TLI->has(LibFunc_small_printf) && !callHasFP128Argument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + auto SmallPrintFFn = + M->getOrInsertFunction(TLI->getName(LibFunc_small_printf), + FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(SmallPrintFFn); + B.Insert(New); + return New; + } + return nullptr; } @@ -2077,7 +2181,8 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, IRBuilder<> &B) { } if (FormatStr[1] == 's') { - // sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) + // sprintf(dest, "%s", str) -> llvm.memcpy(align 1 dest, align 1 str, + // strlen(str)+1) if (!CI->getArgOperand(2)->getType()->isPointerTy()) return nullptr; @@ -2105,13 +2210,27 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) { // point arguments. if (TLI->has(LibFunc_siprintf) && !callHasFloatingPointArgument(CI)) { Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *SIPrintFFn = + FunctionCallee SIPrintFFn = M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(SIPrintFFn); B.Insert(New); return New; } + + // sprintf(str, format, ...) -> __small_sprintf(str, format, ...) if no 128-bit + // floating point arguments. + if (TLI->has(LibFunc_small_sprintf) && !callHasFP128Argument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + auto SmallSPrintFFn = + M->getOrInsertFunction(TLI->getName(LibFunc_small_sprintf), + FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(SmallSPrintFFn); + B.Insert(New); + return New; + } + return nullptr; } @@ -2140,7 +2259,7 @@ Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, IRBuilder<> &B) { else if (N < FormatStr.size() + 1) return nullptr; - // sprintf(str, size, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, + // snprintf(dst, size, fmt) -> llvm.memcpy(align 1 dst, align 1 fmt, // strlen(fmt)+1) B.CreateMemCpy( CI->getArgOperand(0), 1, CI->getArgOperand(2), 1, @@ -2262,13 +2381,27 @@ Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilder<> &B) { // floating point arguments. if (TLI->has(LibFunc_fiprintf) && !callHasFloatingPointArgument(CI)) { Module *M = B.GetInsertBlock()->getParent()->getParent(); - Constant *FIPrintFFn = + FunctionCallee FIPrintFFn = M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(FIPrintFFn); B.Insert(New); return New; } + + // fprintf(stream, format, ...) -> __small_fprintf(stream, format, ...) if no + // 128-bit floating point arguments. 
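A source-level sketch of the sprintf(dest, "%s", str) -> llvm.memcpy(dest, str, strlen(str)+1) fold documented in the updated comment above (illustrative only, not from the diff): the copy must include the NUL terminator, and sprintf's return value is the number of characters written, i.e. strlen(str).

#include <cstdio>
#include <cstring>
#include <cassert>

int main() {
  const char *str = "hello";
  char a[16], b[16];
  int ret = std::sprintf(a, "%s", str);
  std::size_t len = std::strlen(str);
  std::memcpy(b, str, len + 1);          // copy the NUL as well
  assert(ret == (int)len);               // sprintf returns strlen(str)
  assert(std::strcmp(a, b) == 0);        // identical observable result
  return 0;
}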
+ if (TLI->has(LibFunc_small_fprintf) && !callHasFP128Argument(CI)) { + Module *M = B.GetInsertBlock()->getParent()->getParent(); + auto SmallFPrintFFn = + M->getOrInsertFunction(TLI->getName(LibFunc_small_fprintf), + FT, Callee->getAttributes()); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(SmallFPrintFFn); + B.Insert(New); + return New; + } + return nullptr; } @@ -2288,7 +2421,8 @@ Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilder<> &B) { // If this is writing one byte, turn it into fputc. // This optimisation is only valid, if the return value is unused. if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F) - Value *Char = B.CreateLoad(castToCStr(CI->getArgOperand(0), B), "char"); + Value *Char = B.CreateLoad(B.getInt8Ty(), + castToCStr(CI->getArgOperand(0), B), "char"); Value *NewCI = emitFPutC(Char, CI->getArgOperand(3), B, TLI); return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr; } @@ -2307,7 +2441,9 @@ Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilder<> &B) { // Don't rewrite fputs to fwrite when optimising for size because fwrite // requires more arguments and thus extra MOVs are required. - if (CI->getFunction()->optForSize()) + bool OptForSize = CI->getFunction()->hasOptSize() || + llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI); + if (OptForSize) return nullptr; // Check if has any use @@ -2320,7 +2456,7 @@ Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilder<> &B) { return nullptr; } - // fputs(s,F) --> fwrite(s,1,strlen(s),F) + // fputs(s,F) --> fwrite(s,strlen(s),1,F) uint64_t Len = GetStringLength(CI->getArgOperand(0)); if (!Len) return nullptr; @@ -2367,18 +2503,14 @@ Value *LibCallSimplifier::optimizeFRead(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) { - // Check for a constant string. - StringRef Str; - if (!getConstantStringInfo(CI->getArgOperand(0), Str)) + if (!CI->use_empty()) return nullptr; - if (Str.empty() && CI->use_empty()) { - // puts("") -> putchar('\n') - Value *Res = emitPutChar(B.getInt32('\n'), B, TLI); - if (CI->use_empty() || !Res) - return Res; - return B.CreateIntCast(Res, CI->getType(), true); - } + // Check for a constant string. 
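A sketch of the fputs(s, F) -> fwrite(s, strlen(s), 1, F) rewrite referenced in the corrected comment above (illustrative only, not from the diff). The hunk also gates the rewrite on not optimizing for size, now including profile-guided size optimization via PSI/BFI, since fwrite takes more arguments than fputs:

#include <cstdio>
#include <cstring>

int main() {
  const char *s = "hello\n";
  std::FILE *f = std::tmpfile();
  if (!f)
    return 0;
  std::fputs(s, f);                        // original call
  std::fwrite(s, std::strlen(s), 1, f);    // equivalent when the length is known
  std::fclose(f);
  return 0;
}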
+ // puts("") -> putchar('\n') + StringRef Str; + if (getConstantStringInfo(CI->getArgOperand(0), Str) && Str.empty()) + return emitPutChar(B.getInt32('\n'), B, TLI); return nullptr; } @@ -2441,6 +2573,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, return optimizeStrStr(CI, Builder); case LibFunc_memchr: return optimizeMemChr(CI, Builder); + case LibFunc_bcmp: + return optimizeBCmp(CI, Builder); case LibFunc_memcmp: return optimizeMemCmp(CI, Builder); case LibFunc_memcpy: @@ -2686,9 +2820,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { LibCallSimplifier::LibCallSimplifier( const DataLayout &DL, const TargetLibraryInfo *TLI, OptimizationRemarkEmitter &ORE, + BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, function_ref<void(Instruction *, Value *)> Replacer, function_ref<void(Instruction *)> Eraser) - : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), ORE(ORE), + : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), ORE(ORE), BFI(BFI), PSI(PSI), UnsafeFPShrink(false), Replacer(Replacer), Eraser(Eraser) {} void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) { @@ -2735,12 +2870,23 @@ void LibCallSimplifier::eraseFromParent(Instruction *I) { // Fortified Library Call Optimizations //===----------------------------------------------------------------------===// -bool FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, - unsigned ObjSizeOp, - unsigned SizeOp, - bool isString) { - if (CI->getArgOperand(ObjSizeOp) == CI->getArgOperand(SizeOp)) +bool +FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, + unsigned ObjSizeOp, + Optional<unsigned> SizeOp, + Optional<unsigned> StrOp, + Optional<unsigned> FlagOp) { + // If this function takes a flag argument, the implementation may use it to + // perform extra checks. Don't fold into the non-checking variant. + if (FlagOp) { + ConstantInt *Flag = dyn_cast<ConstantInt>(CI->getArgOperand(*FlagOp)); + if (!Flag || !Flag->isZero()) + return false; + } + + if (SizeOp && CI->getArgOperand(ObjSizeOp) == CI->getArgOperand(*SizeOp)) return true; + if (ConstantInt *ObjSizeCI = dyn_cast<ConstantInt>(CI->getArgOperand(ObjSizeOp))) { if (ObjSizeCI->isMinusOne()) @@ -2748,23 +2894,27 @@ bool FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, // If the object size wasn't -1 (unknown), bail out if we were asked to. if (OnlyLowerUnknownSize) return false; - if (isString) { - uint64_t Len = GetStringLength(CI->getArgOperand(SizeOp)); + if (StrOp) { + uint64_t Len = GetStringLength(CI->getArgOperand(*StrOp)); // If the length is 0 we don't know how long it is and so we can't // remove the check. 
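A sketch of the simplified puts handling (illustrative only, not from the diff): puts("") writes just a newline, so it can become putchar('\n'); the rewritten code only performs the fold when the call result is unused because puts and putchar do not return the same value.

#include <cstdio>

int main() {
  std::puts("");          // writes "\n"
  std::putchar('\n');     // same observable output; return values differ,
                          // hence the use_empty() requirement above
  return 0;
}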
if (Len == 0) return false; return ObjSizeCI->getZExtValue() >= Len; } - if (ConstantInt *SizeCI = dyn_cast<ConstantInt>(CI->getArgOperand(SizeOp))) - return ObjSizeCI->getZExtValue() >= SizeCI->getZExtValue(); + + if (SizeOp) { + if (ConstantInt *SizeCI = + dyn_cast<ConstantInt>(CI->getArgOperand(*SizeOp))) + return ObjSizeCI->getZExtValue() >= SizeCI->getZExtValue(); + } } return false; } Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, IRBuilder<> &B) { - if (isFortifiedCallFoldable(CI, 3, 2, false)) { + if (isFortifiedCallFoldable(CI, 3, 2)) { B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, CI->getArgOperand(2)); return CI->getArgOperand(0); @@ -2774,7 +2924,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, IRBuilder<> &B) { - if (isFortifiedCallFoldable(CI, 3, 2, false)) { + if (isFortifiedCallFoldable(CI, 3, 2)) { B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, CI->getArgOperand(2)); return CI->getArgOperand(0); @@ -2786,7 +2936,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, IRBuilder<> &B) { // TODO: Try foldMallocMemset() here. - if (isFortifiedCallFoldable(CI, 3, 2, false)) { + if (isFortifiedCallFoldable(CI, 3, 2)) { Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); return CI->getArgOperand(0); @@ -2797,8 +2947,6 @@ Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, IRBuilder<> &B, LibFunc Func) { - Function *Callee = CI->getCalledFunction(); - StringRef Name = Callee->getName(); const DataLayout &DL = CI->getModule()->getDataLayout(); Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1), *ObjSize = CI->getArgOperand(2); @@ -2814,8 +2962,12 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, // st[rp]cpy_chk call which may fail at runtime if the size is too long. // TODO: It might be nice to get a maximum length out of the possible // string lengths for varying. 
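A source-level sketch of what "foldable" means for the fortified calls above (illustrative only, not from the diff; memcpy_chk_fold is a hypothetical helper, not a libc or LLVM API): when the object-size argument is -1 (unknown) or provably at least the copy length, __memcpy_chk behaves like plain memcpy, so the checking variant can be lowered to the ordinary libcall or intrinsic.

#include <cstring>
#include <cstddef>
#include <cassert>

static void *memcpy_chk_fold(void *dst, const void *src, std::size_t n,
                             std::size_t objsz) {
  // The pass only folds when objsz == (size_t)-1 or objsz >= n is known at compile time.
  assert(objsz == (std::size_t)-1 || objsz >= n);
  return std::memcpy(dst, src, n);
}

int main() {
  char dst[8];
  const char src[8] = "abcdefg";
  memcpy_chk_fold(dst, src, sizeof(src), sizeof(dst));
  assert(std::memcmp(dst, src, sizeof(src)) == 0);
  return 0;
}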
- if (isFortifiedCallFoldable(CI, 2, 1, true)) - return emitStrCpy(Dst, Src, B, TLI, Name.substr(2, 6)); + if (isFortifiedCallFoldable(CI, 2, None, 1)) { + if (Func == LibFunc_strcpy_chk) + return emitStrCpy(Dst, Src, B, TLI); + else + return emitStpCpy(Dst, Src, B, TLI); + } if (OnlyLowerUnknownSize) return nullptr; @@ -2838,13 +2990,99 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeStrpNCpyChk(CallInst *CI, IRBuilder<> &B, LibFunc Func) { - Function *Callee = CI->getCalledFunction(); - StringRef Name = Callee->getName(); - if (isFortifiedCallFoldable(CI, 3, 2, false)) { - Value *Ret = emitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), B, TLI, Name.substr(2, 7)); - return Ret; + if (isFortifiedCallFoldable(CI, 3, 2)) { + if (Func == LibFunc_strncpy_chk) + return emitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), B, TLI); + else + return emitStpNCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), B, TLI); } + + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeMemCCpyChk(CallInst *CI, + IRBuilder<> &B) { + if (isFortifiedCallFoldable(CI, 4, 3)) + return emitMemCCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), CI->getArgOperand(3), B, TLI); + + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeSNPrintfChk(CallInst *CI, + IRBuilder<> &B) { + if (isFortifiedCallFoldable(CI, 3, 1, None, 2)) { + SmallVector<Value *, 8> VariadicArgs(CI->arg_begin() + 5, CI->arg_end()); + return emitSNPrintf(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(4), VariadicArgs, B, TLI); + } + + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeSPrintfChk(CallInst *CI, + IRBuilder<> &B) { + if (isFortifiedCallFoldable(CI, 2, None, None, 1)) { + SmallVector<Value *, 8> VariadicArgs(CI->arg_begin() + 4, CI->arg_end()); + return emitSPrintf(CI->getArgOperand(0), CI->getArgOperand(3), VariadicArgs, + B, TLI); + } + + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeStrCatChk(CallInst *CI, + IRBuilder<> &B) { + if (isFortifiedCallFoldable(CI, 2)) + return emitStrCat(CI->getArgOperand(0), CI->getArgOperand(1), B, TLI); + + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeStrLCat(CallInst *CI, + IRBuilder<> &B) { + if (isFortifiedCallFoldable(CI, 3)) + return emitStrLCat(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), B, TLI); + + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeStrNCatChk(CallInst *CI, + IRBuilder<> &B) { + if (isFortifiedCallFoldable(CI, 3)) + return emitStrNCat(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), B, TLI); + + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeStrLCpyChk(CallInst *CI, + IRBuilder<> &B) { + if (isFortifiedCallFoldable(CI, 3)) + return emitStrLCpy(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(2), B, TLI); + + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeVSNPrintfChk(CallInst *CI, + IRBuilder<> &B) { + if (isFortifiedCallFoldable(CI, 3, 1, None, 2)) + return emitVSNPrintf(CI->getArgOperand(0), CI->getArgOperand(1), + CI->getArgOperand(4), CI->getArgOperand(5), B, TLI); + + return nullptr; +} + +Value *FortifiedLibCallSimplifier::optimizeVSPrintfChk(CallInst *CI, + IRBuilder<> &B) { + if (isFortifiedCallFoldable(CI, 2, None, None, 1)) + return emitVSPrintf(CI->getArgOperand(0), CI->getArgOperand(3), + 
CI->getArgOperand(4), B, TLI); + return nullptr; } @@ -2892,6 +3130,24 @@ Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) { case LibFunc_stpncpy_chk: case LibFunc_strncpy_chk: return optimizeStrpNCpyChk(CI, Builder, Func); + case LibFunc_memccpy_chk: + return optimizeMemCCpyChk(CI, Builder); + case LibFunc_snprintf_chk: + return optimizeSNPrintfChk(CI, Builder); + case LibFunc_sprintf_chk: + return optimizeSPrintfChk(CI, Builder); + case LibFunc_strcat_chk: + return optimizeStrCatChk(CI, Builder); + case LibFunc_strlcat_chk: + return optimizeStrLCat(CI, Builder); + case LibFunc_strncat_chk: + return optimizeStrNCatChk(CI, Builder); + case LibFunc_strlcpy_chk: + return optimizeStrLCpyChk(CI, Builder); + case LibFunc_vsnprintf_chk: + return optimizeVSNPrintfChk(CI, Builder); + case LibFunc_vsprintf_chk: + return optimizeVSPrintfChk(CI, Builder); default: break; } diff --git a/lib/Transforms/Utils/SizeOpts.cpp b/lib/Transforms/Utils/SizeOpts.cpp new file mode 100644 index 000000000000..1519751197d2 --- /dev/null +++ b/lib/Transforms/Utils/SizeOpts.cpp @@ -0,0 +1,37 @@ +//===-- SizeOpts.cpp - code size optimization related code ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some shared code size optimization related code. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Utils/SizeOpts.h" +using namespace llvm; + +static cl::opt<bool> ProfileGuidedSizeOpt( + "pgso", cl::Hidden, cl::init(true), + cl::desc("Enable the profile guided size optimization. ")); + +bool llvm::shouldOptimizeForSize(Function *F, ProfileSummaryInfo *PSI, + BlockFrequencyInfo *BFI) { + assert(F); + if (!PSI || !BFI || !PSI->hasProfileSummary()) + return false; + return ProfileGuidedSizeOpt && PSI->isFunctionColdInCallGraph(F, *BFI); +} + +bool llvm::shouldOptimizeForSize(BasicBlock *BB, ProfileSummaryInfo *PSI, + BlockFrequencyInfo *BFI) { + assert(BB); + if (!PSI || !BFI || !PSI->hasProfileSummary()) + return false; + return ProfileGuidedSizeOpt && PSI->isColdBlock(BB, BFI); +} diff --git a/lib/Transforms/Utils/SplitModule.cpp b/lib/Transforms/Utils/SplitModule.cpp index 5db4d2e4df9d..e2c387cb8983 100644 --- a/lib/Transforms/Utils/SplitModule.cpp +++ b/lib/Transforms/Utils/SplitModule.cpp @@ -1,9 +1,8 @@ //===- SplitModule.cpp - Split a module into partitions -------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/StripGCRelocates.cpp b/lib/Transforms/Utils/StripGCRelocates.cpp index ac0b519f4a77..50844cf9d1c5 100644 --- a/lib/Transforms/Utils/StripGCRelocates.cpp +++ b/lib/Transforms/Utils/StripGCRelocates.cpp @@ -1,9 +1,8 @@ //===- StripGCRelocates.cpp - Remove gc.relocates inserted by RewriteStatePoints===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp b/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp index 8956a089a99c..97a4533fabe5 100644 --- a/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp +++ b/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp @@ -1,9 +1,8 @@ //===- StripNonLineTableDebugInfo.cpp -- Strip parts of Debug Info --------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/Transforms/Utils/SymbolRewriter.cpp b/lib/Transforms/Utils/SymbolRewriter.cpp index fd0da79487f1..456724779b43 100644 --- a/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/lib/Transforms/Utils/SymbolRewriter.cpp @@ -1,9 +1,8 @@ //===- SymbolRewriter.cpp - Symbol Rewriter -------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp b/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp index d49b26472548..7f7bdf8a3d6d 100644 --- a/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp +++ b/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp @@ -1,9 +1,8 @@ //===- UnifyFunctionExitNodes.cpp - Make all functions have a single exit -===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // diff --git a/lib/Transforms/Utils/Utils.cpp b/lib/Transforms/Utils/Utils.cpp index 95416de07439..5272ab6e95d5 100644 --- a/lib/Transforms/Utils/Utils.cpp +++ b/lib/Transforms/Utils/Utils.cpp @@ -1,9 +1,8 @@ //===-- Utils.cpp - TransformUtils Infrastructure -------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -55,3 +54,6 @@ void LLVMAddPromoteMemoryToRegisterPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createPromoteMemoryToRegisterPass()); } +void LLVMAddAddDiscriminatorsPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createAddDiscriminatorsPass()); +} diff --git a/lib/Transforms/Utils/VNCoercion.cpp b/lib/Transforms/Utils/VNCoercion.cpp index 948d9bd5baad..a77bf50fe10b 100644 --- a/lib/Transforms/Utils/VNCoercion.cpp +++ b/lib/Transforms/Utils/VNCoercion.cpp @@ -14,13 +14,17 @@ namespace VNCoercion { /// Return true if coerceAvailableValueToLoadType will succeed. bool canCoerceMustAliasedValueToLoad(Value *StoredVal, Type *LoadTy, const DataLayout &DL) { + Type *StoredTy = StoredVal->getType(); + if (StoredTy == LoadTy) + return true; + // If the loaded or stored value is an first class array or struct, don't try // to transform them. We need to be able to bitcast to integer. - if (LoadTy->isStructTy() || LoadTy->isArrayTy() || - StoredVal->getType()->isStructTy() || StoredVal->getType()->isArrayTy()) + if (LoadTy->isStructTy() || LoadTy->isArrayTy() || StoredTy->isStructTy() || + StoredTy->isArrayTy()) return false; - uint64_t StoreSize = DL.getTypeSizeInBits(StoredVal->getType()); + uint64_t StoreSize = DL.getTypeSizeInBits(StoredTy); // The store size must be byte-aligned to support future type casts. if (llvm::alignTo(StoreSize, 8) != StoreSize) @@ -31,10 +35,16 @@ bool canCoerceMustAliasedValueToLoad(Value *StoredVal, Type *LoadTy, return false; // Don't coerce non-integral pointers to integers or vice versa. - if (DL.isNonIntegralPointerType(StoredVal->getType()) != - DL.isNonIntegralPointerType(LoadTy)) + if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()) != + DL.isNonIntegralPointerType(LoadTy->getScalarType())) { + // As a special case, allow coercion of memset used to initialize + // an array w/null. Despite non-integral pointers not generally having a + // specific bit pattern, we do assume null is zero. + if (auto *CI = dyn_cast<Constant>(StoredVal)) + return CI->isNullValue(); return false; - + } + return true; } @@ -207,11 +217,22 @@ static int analyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr, /// memdep query of a load that ends up being a clobbering store. int analyzeLoadFromClobberingStore(Type *LoadTy, Value *LoadPtr, StoreInst *DepSI, const DataLayout &DL) { + auto *StoredVal = DepSI->getValueOperand(); + // Cannot handle reading from store of first-class aggregate yet. 
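A sketch of the kind of store-to-load coercion VNCoercion reasons about (illustrative only, not from the diff): a clobbering 32-bit store fully covers a 16-bit load at the same address, so GVN can forward the stored bits instead of reloading. Which bytes the narrow load sees depends on endianness, which the real code derives from DataLayout; this sketch assumes a little-endian host.

#include <cstdint>
#include <cstring>
#include <cassert>

int main() {
  uint32_t stored = 0x11223344u;
  uint16_t loaded;
  std::memcpy(&loaded, &stored, sizeof(loaded));   // forward the low bytes of the store
  assert(loaded == 0x3344u);                       // little-endian assumption
  return 0;
}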
- if (DepSI->getValueOperand()->getType()->isStructTy() || - DepSI->getValueOperand()->getType()->isArrayTy()) + if (StoredVal->getType()->isStructTy() || + StoredVal->getType()->isArrayTy()) return -1; + // Don't coerce non-integral pointers to integers or vice versa. + if (DL.isNonIntegralPointerType(StoredVal->getType()->getScalarType()) != + DL.isNonIntegralPointerType(LoadTy->getScalarType())) { + // Allow casts of zero values to null as a special case + auto *CI = dyn_cast<Constant>(StoredVal); + if (!CI || !CI->isNullValue()) + return -1; + } + Value *StorePtr = DepSI->getPointerOperand(); uint64_t StoreSize = DL.getTypeSizeInBits(DepSI->getValueOperand()->getType()); @@ -228,6 +249,11 @@ int analyzeLoadFromClobberingLoad(Type *LoadTy, Value *LoadPtr, LoadInst *DepLI, if (DepLI->getType()->isStructTy() || DepLI->getType()->isArrayTy()) return -1; + // Don't coerce non-integral pointers to integers or vice versa. + if (DL.isNonIntegralPointerType(DepLI->getType()->getScalarType()) != + DL.isNonIntegralPointerType(LoadTy->getScalarType())) + return -1; + Value *DepPtr = DepLI->getPointerOperand(); uint64_t DepSize = DL.getTypeSizeInBits(DepLI->getType()); int R = analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, DepPtr, DepSize, DL); @@ -264,9 +290,15 @@ int analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, // If this is memset, we just need to see if the offset is valid in the size // of the memset.. - if (MI->getIntrinsicID() == Intrinsic::memset) + if (MI->getIntrinsicID() == Intrinsic::memset) { + if (DL.isNonIntegralPointerType(LoadTy->getScalarType())) { + auto *CI = dyn_cast<ConstantInt>(cast<MemSetInst>(MI)->getValue()); + if (!CI || !CI->isZero()) + return -1; + } return analyzeLoadFromClobberingWrite(LoadTy, LoadPtr, MI->getDest(), MemSizeInBits, DL); + } // If we have a memcpy/memmove, the only case we can handle is if this is a // copy from constant memory. In that case, we can read directly from the @@ -278,7 +310,7 @@ int analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, return -1; GlobalVariable *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(Src, DL)); - if (!GV || !GV->isConstant()) + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) return -1; // See if the access is within the bounds of the transfer. @@ -287,6 +319,12 @@ int analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, if (Offset == -1) return Offset; + // Don't coerce non-integral pointers to integers or vice versa, and the + // memtransfer is implicitly a raw byte code + if (DL.isNonIntegralPointerType(LoadTy->getScalarType())) + // TODO: Can allow nullptrs from constant zeros + return -1; + unsigned AS = Src->getType()->getPointerAddressSpace(); // Otherwise, see if we can constant fold a load from the constant with the // offset applied as appropriate. @@ -386,12 +424,12 @@ Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy, // memdep queries will find the new load. We can't easily remove the old // load completely because it is already in the value numbering table. 
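A sketch of the special case added above for non-integral pointers (illustrative only, not from the diff): forwarding a memset of zero bytes into a pointer-typed load assumes the null pointer is the all-zero bit pattern, which LLVM, and this sketch, take for granted on supported targets even though reading a memset pointer is formally implementation-defined.

#include <cstring>
#include <cassert>

int main() {
  int *slot[4];
  std::memset(slot, 0, sizeof(slot));   // zero-initialize the pointer array
  assert(slot[0] == nullptr);           // all-zero bytes read back as null
  return 0;
}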
IRBuilder<> Builder(SrcVal->getParent(), ++BasicBlock::iterator(SrcVal)); - Type *DestPTy = IntegerType::get(LoadTy->getContext(), NewLoadSize * 8); - DestPTy = - PointerType::get(DestPTy, PtrVal->getType()->getPointerAddressSpace()); + Type *DestTy = IntegerType::get(LoadTy->getContext(), NewLoadSize * 8); + Type *DestPTy = + PointerType::get(DestTy, PtrVal->getType()->getPointerAddressSpace()); Builder.SetCurrentDebugLocation(SrcVal->getDebugLoc()); PtrVal = Builder.CreateBitCast(PtrVal, DestPTy); - LoadInst *NewLoad = Builder.CreateLoad(PtrVal); + LoadInst *NewLoad = Builder.CreateLoad(DestTy, PtrVal); NewLoad->takeName(SrcVal); NewLoad->setAlignment(SrcVal->getAlignment()); diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp index 55fff3f3872a..fbc3407c301f 100644 --- a/lib/Transforms/Utils/ValueMapper.cpp +++ b/lib/Transforms/Utils/ValueMapper.cpp @@ -1,9 +1,8 @@ //===- ValueMapper.cpp - Interface shared by lib/Transforms/Utils ---------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -914,6 +913,21 @@ void Mapper::remapInstruction(Instruction *I) { Tys.push_back(TypeMapper->remapType(Ty)); CS.mutateFunctionType(FunctionType::get( TypeMapper->remapType(I->getType()), Tys, FTy->isVarArg())); + + LLVMContext &C = CS->getContext(); + AttributeList Attrs = CS.getAttributes(); + for (unsigned i = 0; i < Attrs.getNumAttrSets(); ++i) { + if (Attrs.hasAttribute(i, Attribute::ByVal)) { + Type *Ty = Attrs.getAttribute(i, Attribute::ByVal).getValueAsType(); + if (!Ty) + continue; + + Attrs = Attrs.removeAttribute(C, i, Attribute::ByVal); + Attrs = Attrs.addAttribute( + C, i, Attribute::getWithByValType(C, TypeMapper->remapType(Ty))); + } + } + CS.setAttributes(Attrs); return; } if (auto *AI = dyn_cast<AllocaInst>(I)) diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 9ff18328c219..4273080ddd91 100644 --- a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -1,9 +1,8 @@ //===- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer --------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -927,7 +926,7 @@ bool Vectorizer::vectorizeStoreChain( StoreInst *S0 = cast<StoreInst>(Chain[0]); // If the vector has an int element, default to int for the whole store. 
- Type *StoreTy; + Type *StoreTy = nullptr; for (Instruction *I : Chain) { StoreTy = cast<StoreInst>(I)->getValueOperand()->getType(); if (StoreTy->isIntOrIntVectorTy()) @@ -939,6 +938,7 @@ bool Vectorizer::vectorizeStoreChain( break; } } + assert(StoreTy && "Failed to find store type"); unsigned Sz = DL.getTypeSizeInBits(StoreTy); unsigned AS = S0->getPointerAddressSpace(); @@ -1152,13 +1152,8 @@ bool Vectorizer::vectorizeLoadChain( vectorizeLoadChain(Chains.second, InstructionsProcessed); } - unsigned NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(), - StackAdjustedAlignment, - DL, L0, nullptr, &DT); - if (NewAlign != 0) - Alignment = NewAlign; - - Alignment = NewAlign; + Alignment = getOrEnforceKnownAlignment( + L0->getPointerOperand(), StackAdjustedAlignment, DL, L0, nullptr, &DT); } if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) { @@ -1182,7 +1177,7 @@ bool Vectorizer::vectorizeLoadChain( Value *Bitcast = Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS)); - LoadInst *LI = Builder.CreateAlignedLoad(Bitcast, Alignment); + LoadInst *LI = Builder.CreateAlignedLoad(VecTy, Bitcast, Alignment); propagateMetadata(LI, Chain); if (VecLoadTy) { diff --git a/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index b44fe5a52a2f..6ef8dc2d3cd7 100644 --- a/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -1,9 +1,8 @@ //===- LoopVectorizationLegality.cpp --------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -23,6 +22,8 @@ using namespace llvm; #define LV_NAME "loop-vectorize" #define DEBUG_TYPE LV_NAME +extern cl::opt<bool> EnableVPlanPredication; + static cl::opt<bool> EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden, cl::desc("Enable if-conversion during vectorization.")); @@ -46,6 +47,18 @@ static const unsigned MaxInterleaveFactor = 16; namespace llvm { +#ifndef NDEBUG +static void debugVectorizationFailure(const StringRef DebugMsg, + Instruction *I) { + dbgs() << "LV: Not vectorizing: " << DebugMsg; + if (I != nullptr) + dbgs() << " " << *I; + else + dbgs() << '.'; + dbgs() << '\n'; +} +#endif + OptimizationRemarkAnalysis createLVMissedAnalysis(const char *PassName, StringRef RemarkName, Loop *TheLoop, @@ -103,6 +116,25 @@ LoopVectorizeHints::LoopVectorizeHints(const Loop *L, << "LV: Interleaving disabled by the pass manager\n"); } +void LoopVectorizeHints::setAlreadyVectorized() { + LLVMContext &Context = TheLoop->getHeader()->getContext(); + + MDNode *IsVectorizedMD = MDNode::get( + Context, + {MDString::get(Context, "llvm.loop.isvectorized"), + ConstantAsMetadata::get(ConstantInt::get(Context, APInt(32, 1)))}); + MDNode *LoopID = TheLoop->getLoopID(); + MDNode *NewLoopID = + makePostTransformationMetadata(Context, LoopID, + {Twine(Prefix(), "vectorize.").str(), + Twine(Prefix(), "interleave.").str()}, + {IsVectorizedMD}); + TheLoop->setLoopID(NewLoopID); + + // Update internal cache. 
+ IsVectorized.Value = 1; +} + bool LoopVectorizeHints::allowVectorization( Function *F, Loop *L, bool VectorizeOnlyWhenForced) const { if (getForce() == LoopVectorizeHints::FK_Disabled) { @@ -230,57 +262,6 @@ void LoopVectorizeHints::setHint(StringRef Name, Metadata *Arg) { } } -MDNode *LoopVectorizeHints::createHintMetadata(StringRef Name, - unsigned V) const { - LLVMContext &Context = TheLoop->getHeader()->getContext(); - Metadata *MDs[] = { - MDString::get(Context, Name), - ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))}; - return MDNode::get(Context, MDs); -} - -bool LoopVectorizeHints::matchesHintMetadataName(MDNode *Node, - ArrayRef<Hint> HintTypes) { - MDString *Name = dyn_cast<MDString>(Node->getOperand(0)); - if (!Name) - return false; - - for (auto H : HintTypes) - if (Name->getString().endswith(H.Name)) - return true; - return false; -} - -void LoopVectorizeHints::writeHintsToMetadata(ArrayRef<Hint> HintTypes) { - if (HintTypes.empty()) - return; - - // Reserve the first element to LoopID (see below). - SmallVector<Metadata *, 4> MDs(1); - // If the loop already has metadata, then ignore the existing operands. - MDNode *LoopID = TheLoop->getLoopID(); - if (LoopID) { - for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { - MDNode *Node = cast<MDNode>(LoopID->getOperand(i)); - // If node in update list, ignore old value. - if (!matchesHintMetadataName(Node, HintTypes)) - MDs.push_back(Node); - } - } - - // Now, add the missing hints. - for (auto H : HintTypes) - MDs.push_back(createHintMetadata(Twine(Prefix(), H.Name).str(), H.Value)); - - // Replace current metadata node with new one. - LLVMContext &Context = TheLoop->getHeader()->getContext(); - MDNode *NewLoopID = MDNode::get(Context, MDs); - // Set operand 0 to refer to the loop id itself. - NewLoopID->replaceOperandWith(0, NewLoopID); - - TheLoop->setLoopID(NewLoopID); -} - bool LoopVectorizationRequirements::doesNotMeet( Function *F, Loop *L, const LoopVectorizeHints &Hints) { const char *PassName = Hints.vectorizeAnalysisPassName(); @@ -464,6 +445,14 @@ bool LoopVectorizationLegality::isUniform(Value *V) { return LAI->isUniform(V); } +void LoopVectorizationLegality::reportVectorizationFailure( + const StringRef DebugMsg, const StringRef OREMsg, + const StringRef ORETag, Instruction *I) const { + LLVM_DEBUG(debugVectorizationFailure(DebugMsg, I)); + ORE->emit(createLVMissedAnalysis(Hints->vectorizeAnalysisPassName(), + ORETag, TheLoop, I) << OREMsg); +} + bool LoopVectorizationLegality::canVectorizeOuterLoop() { assert(!TheLoop->empty() && "We are not vectorizing an outer loop."); // Store the result and return it at the end instead of exiting early, in case @@ -476,9 +465,9 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() { // not supported yet. auto *Br = dyn_cast<BranchInst>(BB->getTerminator()); if (!Br) { - LLVM_DEBUG(dbgs() << "LV: Unsupported basic block terminator.\n"); - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); + reportVectorizationFailure("Unsupported basic block terminator", + "loop control flow is not understood by vectorizer", + "CFGNotUnderstood"); if (DoExtraAnalysis) Result = false; else @@ -488,13 +477,16 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() { // Check whether the BranchInst is a supported one. Only unconditional // branches, conditional branches with an outer loop invariant condition or // backedges are supported. 
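A minimal standalone sketch, deliberately not LLVM code, of the pattern the new reportVectorizationFailure helper centralizes: one call site produces both the "LV: Not vectorizing: ..." debug trace and the optimization remark, so the two messages can no longer drift apart.

#include <cstdio>

static void reportFailure(const char *DebugMsg, const char *RemarkMsg,
                          const char *RemarkTag) {
  std::fprintf(stderr, "LV: Not vectorizing: %s.\n", DebugMsg);     // debug trace
  std::fprintf(stdout, "remark [%s]: %s\n", RemarkTag, RemarkMsg);  // stand-in for ORE->emit
}

int main() {
  reportFailure("Unsupported basic block terminator",
                "loop control flow is not understood by vectorizer",
                "CFGNotUnderstood");
  return 0;
}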
- if (Br && Br->isConditional() && + // FIXME: We skip these checks when VPlan predication is enabled as we + // want to allow divergent branches. This whole check will be removed + // once VPlan predication is on by default. + if (!EnableVPlanPredication && Br && Br->isConditional() && !TheLoop->isLoopInvariant(Br->getCondition()) && !LI->isLoopHeader(Br->getSuccessor(0)) && !LI->isLoopHeader(Br->getSuccessor(1))) { - LLVM_DEBUG(dbgs() << "LV: Unsupported conditional branch.\n"); - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); + reportVectorizationFailure("Unsupported conditional branch", + "loop control flow is not understood by vectorizer", + "CFGNotUnderstood"); if (DoExtraAnalysis) Result = false; else @@ -506,11 +498,9 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() { // simple outer loops scenarios with uniform nested loops. if (!isUniformLoopNest(TheLoop /*loop nest*/, TheLoop /*context outer loop*/)) { - LLVM_DEBUG( - dbgs() - << "LV: Not vectorizing: Outer loop contains divergent loops.\n"); - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); + reportVectorizationFailure("Outer loop contains divergent loops", + "loop control flow is not understood by vectorizer", + "CFGNotUnderstood"); if (DoExtraAnalysis) Result = false; else @@ -519,10 +509,9 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() { // Check whether we are able to set up outer loop induction. if (!setupOuterLoopInductions()) { - LLVM_DEBUG( - dbgs() << "LV: Not vectorizing: Unsupported outer loop Phi(s).\n"); - ORE->emit(createMissedAnalysis("UnsupportedPhi") - << "Unsupported outer loop Phi(s)"); + reportVectorizationFailure("Unsupported outer loop Phi(s)", + "Unsupported outer loop Phi(s)", + "UnsupportedPhi"); if (DoExtraAnalysis) Result = false; else @@ -627,9 +616,9 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // Check that this PHI type is allowed. if (!PhiTy->isIntegerTy() && !PhiTy->isFloatingPointTy() && !PhiTy->isPointerTy()) { - ORE->emit(createMissedAnalysis("CFGNotUnderstood", Phi) - << "loop control flow is not understood by vectorizer"); - LLVM_DEBUG(dbgs() << "LV: Found an non-int non-pointer PHI.\n"); + reportVectorizationFailure("Found a non-int non-pointer PHI", + "loop control flow is not understood by vectorizer", + "CFGNotUnderstood"); return false; } @@ -647,9 +636,9 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // We only allow if-converted PHIs with exactly two incoming values. if (Phi->getNumIncomingValues() != 2) { - ORE->emit(createMissedAnalysis("CFGNotUnderstood", Phi) - << "control flow not understood by vectorizer"); - LLVM_DEBUG(dbgs() << "LV: Found an invalid PHI.\n"); + reportVectorizationFailure("Found an invalid PHI", + "loop control flow is not understood by vectorizer", + "CFGNotUnderstood", Phi); return false; } @@ -698,10 +687,10 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { continue; } - ORE->emit(createMissedAnalysis("NonReductionValueUsedOutsideLoop", Phi) - << "value that could not be identified as " - "reduction is used outside the loop"); - LLVM_DEBUG(dbgs() << "LV: Found an unidentified PHI." 
<< *Phi << "\n"); + reportVectorizationFailure("Found an unidentified PHI", + "value that could not be identified as " + "reduction is used outside the loop", + "NonReductionValueUsedOutsideLoop", Phi); return false; } // end of PHI handling @@ -728,31 +717,33 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // but it's hard to provide meaningful yet generic advice. // Also, should this be guarded by allowExtraAnalysis() and/or be part // of the returned info from isFunctionVectorizable()? - ORE->emit(createMissedAnalysis("CantVectorizeLibcall", CI) - << "library call cannot be vectorized. " - "Try compiling with -fno-math-errno, -ffast-math, " - "or similar flags"); + reportVectorizationFailure("Found a non-intrinsic callsite", + "library call cannot be vectorized. " + "Try compiling with -fno-math-errno, -ffast-math, " + "or similar flags", + "CantVectorizeLibcall", CI); } else { - ORE->emit(createMissedAnalysis("CantVectorizeCall", CI) - << "call instruction cannot be vectorized"); + reportVectorizationFailure("Found a non-intrinsic callsite", + "call instruction cannot be vectorized", + "CantVectorizeLibcall", CI); } - LLVM_DEBUG( - dbgs() << "LV: Found a non-intrinsic callsite.\n"); return false; } - // Intrinsics such as powi,cttz and ctlz are legal to vectorize if the - // second argument is the same (i.e. loop invariant) - if (CI && hasVectorInstrinsicScalarOpd( - getVectorIntrinsicIDForCall(CI, TLI), 1)) { + // Some intrinsics have scalar arguments and should be same in order for + // them to be vectorized (i.e. loop invariant). + if (CI) { auto *SE = PSE.getSE(); - if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(1)), TheLoop)) { - ORE->emit(createMissedAnalysis("CantVectorizeIntrinsic", CI) - << "intrinsic instruction cannot be vectorized"); - LLVM_DEBUG(dbgs() - << "LV: Found unvectorizable intrinsic " << *CI << "\n"); - return false; - } + Intrinsic::ID IntrinID = getVectorIntrinsicIDForCall(CI, TLI); + for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) + if (hasVectorInstrinsicScalarOpd(IntrinID, i)) { + if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(i)), TheLoop)) { + reportVectorizationFailure("Found unvectorizable intrinsic", + "intrinsic instruction cannot be vectorized", + "CantVectorizeIntrinsic", CI); + return false; + } + } } // Check that the instruction return type is vectorizable. @@ -760,9 +751,9 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { if ((!VectorType::isValidElementType(I.getType()) && !I.getType()->isVoidTy()) || isa<ExtractElementInst>(I)) { - ORE->emit(createMissedAnalysis("CantVectorizeInstructionReturnType", &I) - << "instruction return type cannot be vectorized"); - LLVM_DEBUG(dbgs() << "LV: Found unvectorizable type.\n"); + reportVectorizationFailure("Found unvectorizable type", + "instruction return type cannot be vectorized", + "CantVectorizeInstructionReturnType", &I); return false; } @@ -770,11 +761,44 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { if (auto *ST = dyn_cast<StoreInst>(&I)) { Type *T = ST->getValueOperand()->getType(); if (!VectorType::isValidElementType(T)) { - ORE->emit(createMissedAnalysis("CantVectorizeStore", ST) - << "store instruction cannot be vectorized"); + reportVectorizationFailure("Store instruction cannot be vectorized", + "store instruction cannot be vectorized", + "CantVectorizeStore", ST); return false; } + // For nontemporal stores, check that a nontemporal vector version is + // supported on the target. 
+ if (ST->getMetadata(LLVMContext::MD_nontemporal)) { + // Arbitrarily try a vector of 2 elements. + Type *VecTy = VectorType::get(T, /*NumElements=*/2); + assert(VecTy && "did not find vectorized version of stored type"); + unsigned Alignment = getLoadStoreAlignment(ST); + if (!TTI->isLegalNTStore(VecTy, Alignment)) { + reportVectorizationFailure( + "nontemporal store instruction cannot be vectorized", + "nontemporal store instruction cannot be vectorized", + "CantVectorizeNontemporalStore", ST); + return false; + } + } + + } else if (auto *LD = dyn_cast<LoadInst>(&I)) { + if (LD->getMetadata(LLVMContext::MD_nontemporal)) { + // For nontemporal loads, check that a nontemporal vector version is + // supported on the target (arbitrarily try a vector of 2 elements). + Type *VecTy = VectorType::get(I.getType(), /*NumElements=*/2); + assert(VecTy && "did not find vectorized version of load type"); + unsigned Alignment = getLoadStoreAlignment(LD); + if (!TTI->isLegalNTLoad(VecTy, Alignment)) { + reportVectorizationFailure( + "nontemporal load instruction cannot be vectorized", + "nontemporal load instruction cannot be vectorized", + "CantVectorizeNontemporalLoad", LD); + return false; + } + } + // FP instructions can allow unsafe algebra, thus vectorizable by // non-IEEE-754 compliant SIMD units. // This applies to floating-point math operations and calls, not memory @@ -797,23 +821,27 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { AllowedExit.insert(&I); continue; } - ORE->emit(createMissedAnalysis("ValueUsedOutsideLoop", &I) - << "value cannot be used outside the loop"); + reportVectorizationFailure("Value cannot be used outside the loop", + "value cannot be used outside the loop", + "ValueUsedOutsideLoop", &I); return false; } } // next instr. 
} if (!PrimaryInduction) { - LLVM_DEBUG(dbgs() << "LV: Did not find one integer induction var.\n"); if (Inductions.empty()) { - ORE->emit(createMissedAnalysis("NoInductionVariable") - << "loop induction variable could not be identified"); + reportVectorizationFailure("Did not find one integer induction var", + "loop induction variable could not be identified", + "NoInductionVariable"); return false; } else if (!WidestIndTy) { - ORE->emit(createMissedAnalysis("NoIntegerInductionVariable") - << "integer loop induction variable could not be identified"); + reportVectorizationFailure("Did not find one integer induction var", + "integer loop induction variable could not be identified", + "NoIntegerInductionVariable"); return false; + } else { + LLVM_DEBUG(dbgs() << "LV: Did not find one integer induction var.\n"); } } @@ -839,11 +867,9 @@ bool LoopVectorizationLegality::canVectorizeMemory() { return false; if (LAI->hasDependenceInvolvingLoopInvariantAddress()) { - ORE->emit(createMissedAnalysis("CantVectorizeStoreToLoopInvariantAddress") - << "write to a loop invariant address could not " - "be vectorized"); - LLVM_DEBUG( - dbgs() << "LV: Non vectorizable stores to a uniform address\n"); + reportVectorizationFailure("Stores to a uniform address", + "write to a loop invariant address could not be vectorized", + "CantVectorizeStoreToLoopInvariantAddress"); return false; } Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); @@ -925,8 +951,9 @@ bool LoopVectorizationLegality::blockCanBePredicated( bool LoopVectorizationLegality::canVectorizeWithIfConvert() { if (!EnableIfConversion) { - ORE->emit(createMissedAnalysis("IfConversionDisabled") - << "if-conversion is disabled"); + reportVectorizationFailure("If-conversion is disabled", + "if-conversion is disabled", + "IfConversionDisabled"); return false; } @@ -950,21 +977,26 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() { for (BasicBlock *BB : TheLoop->blocks()) { // We don't support switch statements inside loops. if (!isa<BranchInst>(BB->getTerminator())) { - ORE->emit(createMissedAnalysis("LoopContainsSwitch", BB->getTerminator()) - << "loop contains a switch statement"); + reportVectorizationFailure("Loop contains a switch statement", + "loop contains a switch statement", + "LoopContainsSwitch", BB->getTerminator()); return false; } // We must be able to predicate all blocks that need to be predicated. if (blockNeedsPredication(BB)) { if (!blockCanBePredicated(BB, SafePointes)) { - ORE->emit(createMissedAnalysis("NoCFGForSelect", BB->getTerminator()) - << "control flow cannot be substituted for a select"); + reportVectorizationFailure( + "Control flow cannot be substituted for a select", + "control flow cannot be substituted for a select", + "NoCFGForSelect", BB->getTerminator()); return false; } } else if (BB != Header && !canIfConvertPHINodes(BB)) { - ORE->emit(createMissedAnalysis("NoCFGForSelect", BB->getTerminator()) - << "control flow cannot be substituted for a select"); + reportVectorizationFailure( + "Control flow cannot be substituted for a select", + "control flow cannot be substituted for a select", + "NoCFGForSelect", BB->getTerminator()); return false; } } @@ -992,9 +1024,9 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp, // We must have a loop in canonical form. Loops with indirectbr in them cannot // be canonicalized. 
if (!Lp->getLoopPreheader()) { - LLVM_DEBUG(dbgs() << "LV: Loop doesn't have a legal pre-header.\n"); - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); + reportVectorizationFailure("Loop doesn't have a legal pre-header", + "loop control flow is not understood by vectorizer", + "CFGNotUnderstood"); if (DoExtraAnalysis) Result = false; else @@ -1003,8 +1035,9 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp, // We must have a single backedge. if (Lp->getNumBackEdges() != 1) { - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); + reportVectorizationFailure("The loop must have a single backedge", + "loop control flow is not understood by vectorizer", + "CFGNotUnderstood"); if (DoExtraAnalysis) Result = false; else @@ -1013,8 +1046,9 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp, // We must have a single exiting block. if (!Lp->getExitingBlock()) { - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); + reportVectorizationFailure("The loop must have an exiting block", + "loop control flow is not understood by vectorizer", + "CFGNotUnderstood"); if (DoExtraAnalysis) Result = false; else @@ -1025,8 +1059,9 @@ bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp, // checked at the end of each iteration. With that we can assume that all // instructions in the loop are executed the same number of times. if (Lp->getExitingBlock() != Lp->getLoopLatch()) { - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); + reportVectorizationFailure("The exiting block is not the loop latch", + "loop control flow is not understood by vectorizer", + "CFGNotUnderstood"); if (DoExtraAnalysis) Result = false; else @@ -1087,7 +1122,9 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { assert(UseVPlanNativePath && "VPlan-native path is not enabled."); if (!canVectorizeOuterLoop()) { - LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Unsupported outer loop.\n"); + reportVectorizationFailure("Unsupported outer loop", + "unsupported outer loop", + "UnsupportedOuterLoop"); // TODO: Implement DoExtraAnalysis when subsequent legal checks support // outer loops. 
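Most of the hunks above apply one mechanical change: a paired LLVM_DEBUG print plus ORE->emit(createMissedAnalysis(...)) is collapsed into a single reportVectorizationFailure call that carries a debug string, a user-facing remark string, and a remark tag (plus an optional instruction). The stand-alone sketch below is not taken from the patch; it only mirrors that idea with plain stand-in types, since the real helper belongs to LoopVectorizationLegality and its exact signature is not visible in this diff.

  #include <iostream>
  #include <string>

  // Stand-in for the optimization-remark emitter; name and interface are
  // illustrative, not LLVM's.
  struct RemarkSink {
    void missed(const std::string &Tag, const std::string &Msg) {
      std::cout << "remark: loop not vectorized [" << Tag << "]: " << Msg << "\n";
    }
  };

  // One helper keeps the internal debug text and the user-facing remark together
  // instead of duplicating them at every bail-out site.
  void reportVectorizationFailure(const std::string &DebugMsg,
                                  const std::string &RemarkMsg,
                                  const std::string &Tag, RemarkSink &ORE) {
    std::cerr << "LV: Not vectorizing: " << DebugMsg << ".\n"; // debug stream
    ORE.missed(Tag, RemarkMsg);                                // user remark
  }

  int main() {
    RemarkSink ORE;
    reportVectorizationFailure("The loop must have a single backedge",
                               "loop control flow is not understood by vectorizer",
                               "CFGNotUnderstood", ORE);
  }

The diff resumes below inside LoopVectorizationLegality::canVectorize().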
return false; @@ -1137,10 +1174,9 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) { - ORE->emit(createMissedAnalysis("TooManySCEVRunTimeChecks") - << "Too many SCEV assumptions need to be made and checked " - << "at runtime"); - LLVM_DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); + reportVectorizationFailure("Too many SCEV checks needed", + "Too many SCEV assumptions need to be made and checked at runtime", + "TooManySCEVRunTimeChecks"); if (DoExtraAnalysis) Result = false; else @@ -1159,20 +1195,20 @@ bool LoopVectorizationLegality::canFoldTailByMasking() { LLVM_DEBUG(dbgs() << "LV: checking if tail can be folded by masking.\n"); if (!PrimaryInduction) { - ORE->emit(createMissedAnalysis("NoPrimaryInduction") - << "Missing a primary induction variable in the loop, which is " - << "needed in order to fold tail by masking as required."); - LLVM_DEBUG(dbgs() << "LV: No primary induction, cannot fold tail by " - << "masking.\n"); + reportVectorizationFailure( + "No primary induction, cannot fold tail by masking", + "Missing a primary induction variable in the loop, which is " + "needed in order to fold tail by masking as required.", + "NoPrimaryInduction"); return false; } // TODO: handle reductions when tail is folded by masking. if (!Reductions.empty()) { - ORE->emit(createMissedAnalysis("ReductionFoldingTailByMasking") - << "Cannot fold tail by masking in the presence of reductions."); - LLVM_DEBUG(dbgs() << "LV: Loop has reductions, cannot fold tail by " - << "masking.\n"); + reportVectorizationFailure( + "Loop has reductions, cannot fold tail by masking", + "Cannot fold tail by masking in the presence of reductions.", + "ReductionFoldingTailByMasking"); return false; } @@ -1183,10 +1219,10 @@ bool LoopVectorizationLegality::canFoldTailByMasking() { Instruction *UI = cast<Instruction>(U); if (TheLoop->contains(UI)) continue; - ORE->emit(createMissedAnalysis("LiveOutFoldingTailByMasking") - << "Cannot fold tail by masking in the presence of live outs."); - LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking, loop has an " - << "outside user for : " << *UI << '\n'); + reportVectorizationFailure( + "Cannot fold tail by masking, loop has an outside user for", + "Cannot fold tail by masking in the presence of live outs.", + "LiveOutFoldingTailByMasking", UI); return false; } } @@ -1198,9 +1234,10 @@ bool LoopVectorizationLegality::canFoldTailByMasking() { // do not need predication such as the header block. 
for (BasicBlock *BB : TheLoop->blocks()) { if (!blockCanBePredicated(BB, SafePointers)) { - ORE->emit(createMissedAnalysis("NoCFGForSelect", BB->getTerminator()) - << "control flow cannot be substituted for a select"); - LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking as required.\n"); + reportVectorizationFailure( + "Cannot fold tail by masking as required", + "control flow cannot be substituted for a select", + "NoCFGForSelect", BB->getTerminator()); return false; } } diff --git a/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index 2aa219064299..97077cce83e3 100644 --- a/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -1,9 +1,8 @@ //===- LoopVectorizationPlanner.h - Planner for LoopVectorization ---------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// @@ -172,6 +171,13 @@ struct VectorizationFactor { unsigned Width; // Cost of the loop with that width unsigned Cost; + + // Width 1 means no vectorization, cost 0 means uncomputed cost. + static VectorizationFactor Disabled() { return {1, 0}; } + + bool operator==(const VectorizationFactor &rhs) const { + return Width == rhs.Width && Cost == rhs.Cost; + } }; /// Planner drives the vectorization process after having passed @@ -192,11 +198,9 @@ class LoopVectorizationPlanner { /// The legality analysis. LoopVectorizationLegality *Legal; - /// The profitablity analysis. + /// The profitability analysis. LoopVectorizationCostModel &CM; - using VPlanPtr = std::unique_ptr<VPlan>; - SmallVector<VPlanPtr, 4> VPlans; /// This class is used to enable the VPlan to invoke a method of ILV. This is @@ -222,8 +226,9 @@ public: LoopVectorizationCostModel &CM) : OrigLoop(L), LI(LI), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM) {} - /// Plan how to best vectorize, return the best VF and its cost. - VectorizationFactor plan(bool OptForSize, unsigned UserVF); + /// Plan how to best vectorize, return the best VF and its cost, or None if + /// vectorization and interleaving should be avoided up front. + Optional<VectorizationFactor> plan(bool OptForSize, unsigned UserVF); /// Use the VPlan-native path to plan how to best vectorize, return the best /// VF and its cost. diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index c45dee590b84..46265e3f3e13 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1,9 +1,8 @@ //===- LoopVectorize.cpp - A Loop Vectorizer ------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -57,8 +56,10 @@ #include "llvm/Transforms/Vectorize/LoopVectorize.h" #include "LoopVectorizationPlanner.h" #include "VPRecipeBuilder.h" +#include "VPlan.h" #include "VPlanHCFGBuilder.h" #include "VPlanHCFGTransforms.h" +#include "VPlanPredicator.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -86,7 +87,9 @@ #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -133,6 +136,7 @@ #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Utils/SizeOpts.h" #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" #include <algorithm> #include <cassert> @@ -256,6 +260,13 @@ cl::opt<bool> EnableVPlanNativePath( cl::desc("Enable VPlan-native vectorization path with " "support for outer loop vectorization.")); +// FIXME: Remove this switch once we have divergence analysis. Currently we +// assume divergent non-backedge branches when this switch is true. +cl::opt<bool> EnableVPlanPredication( + "enable-vplan-predication", cl::init(false), cl::Hidden, + cl::desc("Enable VPlan-native vectorization path predicator with " + "support for outer loop vectorization.")); + // This flag enables the stress testing of the VPlan H-CFG construction in the // VPlan-native vectorization path. It must be used in conjuction with // -enable-vplan-native-path. -vplan-verify-hcfg can also be used to enable the @@ -267,6 +278,13 @@ static cl::opt<bool> VPlanBuildStressTest( "out right after the build (stress test the VPlan H-CFG construction " "in the VPlan-native vectorization path).")); +cl::opt<bool> llvm::EnableLoopInterleaving( + "interleave-loops", cl::init(true), cl::Hidden, + cl::desc("Enable loop interleaving in Loop vectorization passes")); +cl::opt<bool> llvm::EnableLoopVectorization( + "vectorize-loops", cl::init(true), cl::Hidden, + cl::desc("Run the Loop vectorization passes")); + /// A helper function for converting Scalar types to vector types. /// If the incoming type is void, we return void. If the VF is 1, we return /// the scalar type. @@ -311,11 +329,14 @@ static unsigned getReciprocalPredBlockProb() { return 2; } /// A helper function that adds a 'fast' flag to floating-point operations. 
static Value *addFastMathFlag(Value *V) { - if (isa<FPMathOperator>(V)) { - FastMathFlags Flags; - Flags.setFast(); - cast<Instruction>(V)->setFastMathFlags(Flags); - } + if (isa<FPMathOperator>(V)) + cast<Instruction>(V)->setFastMathFlags(FastMathFlags::getFast()); + return V; +} + +static Value *addFastMathFlag(Value *V, FastMathFlags FMF) { + if (isa<FPMathOperator>(V)) + cast<Instruction>(V)->setFastMathFlags(FMF); return V; } @@ -760,7 +781,7 @@ void InnerLoopVectorizer::setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) const DILocation *DIL = Inst->getDebugLoc(); if (DIL && Inst->getFunction()->isDebugInfoForProfiling() && !isa<DbgInfoIntrinsic>(Inst)) { - auto NewDIL = DIL->cloneWithDuplicationFactor(UF * VF); + auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(UF * VF); if (NewDIL) B.SetCurrentDebugLocation(NewDIL.getValue()); else @@ -836,7 +857,7 @@ public: AC(AC), ORE(ORE), TheFunction(F), Hints(Hints), InterleaveInfo(IAI) {} /// \return An upper bound for the vectorization factor, or None if - /// vectorization should be avoided up front. + /// vectorization and interleaving should be avoided up front. Optional<unsigned> computeMaxVF(bool OptForSize); /// \return The most profitable vectorization factor and the cost of that VF. @@ -1149,6 +1170,18 @@ public: return foldTailByMasking() || Legal->blockNeedsPredication(BB); } + /// Estimate cost of an intrinsic call instruction CI if it were vectorized + /// with factor VF. Return the cost of the instruction, including + /// scalarization overhead if it's needed. + unsigned getVectorIntrinsicCost(CallInst *CI, unsigned VF); + + /// Estimate cost of a call instruction CI if it were vectorized with factor + /// VF. Return the cost of the instruction, including scalarization overhead + /// if it's needed. The flag NeedToScalarize shows if the call needs to be + /// scalarized - + /// i.e. either vector version isn't available, or is too expensive. + unsigned getVectorCallCost(CallInst *CI, unsigned VF, bool &NeedToScalarize); + private: unsigned NumPredStores = 0; @@ -1201,6 +1234,10 @@ private: /// element) unsigned getUniformMemOpCost(Instruction *I, unsigned VF); + /// Estimate the overhead of scalarizing an instruction. This is a + /// convenience wrapper for the type-based getScalarizationOverhead API. + unsigned getScalarizationOverhead(Instruction *I, unsigned VF); + /// Returns whether the instruction is a load or store and will be a emitted /// as a vector operation. bool isConsecutiveLoadOrStore(Instruction *I); @@ -1295,6 +1332,30 @@ private: DecisionList WideningDecisions; + /// Returns true if \p V is expected to be vectorized and it needs to be + /// extracted. + bool needsExtract(Value *V, unsigned VF) const { + Instruction *I = dyn_cast<Instruction>(V); + if (VF == 1 || !I || !TheLoop->contains(I) || TheLoop->isLoopInvariant(I)) + return false; + + // Assume we can vectorize V (and hence we need extraction) if the + // scalars are not computed yet. This can happen, because it is called + // via getScalarizationOverhead from setCostBasedWideningDecision, before + // the scalars are collected. That should be a safe assumption in most + // cases, because we check if the operands have vectorizable types + // beforehand in LoopVectorizationLegality. + return Scalars.find(VF) == Scalars.end() || + !isScalarAfterVectorization(I, VF); + }; + + /// Returns a range containing only operands needing to be extracted. 
+ SmallVector<Value *, 4> filterExtractingOperands(Instruction::op_range Ops, + unsigned VF) { + return SmallVector<Value *, 4>(make_filter_range( + Ops, [this, VF](Value *V) { return this->needsExtract(V, VF); })); + } + public: /// The loop that we evaluate. Loop *TheLoop; @@ -1372,12 +1433,6 @@ static bool isExplicitVecOuterLoop(Loop *OuterLp, return false; } - if (!Hints.getWidth()) { - LLVM_DEBUG(dbgs() << "LV: Not vectorizing: No user vector width.\n"); - Hints.emitRemarkWithHints(); - return false; - } - if (Hints.getInterleave() > 1) { // TODO: Interleave support is future work. LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Interleave is not supported for " @@ -1447,12 +1502,13 @@ struct LoopVectorize : public FunctionPass { auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits(); auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); std::function<const LoopAccessInfo &(Loop &)> GetLAA = [&](Loop &L) -> const LoopAccessInfo & { return LAA->getInfo(&L); }; return Impl.runImpl(F, *SE, *LI, *TTI, *DT, *BFI, TLI, *DB, *AA, *AC, - GetLAA, *ORE); + GetLAA, *ORE, PSI); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -1478,6 +1534,7 @@ struct LoopVectorize : public FunctionPass { AU.addPreserved<BasicAAWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); } }; @@ -2051,7 +2108,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr, // A[i] = b; // Member of index 0 // A[i+2] = c; // Member of index 2 (Current instruction) // Current pointer is pointed to A[i+2], adjust it to A[i]. - NewPtr = Builder.CreateGEP(NewPtr, Builder.getInt32(-Index)); + NewPtr = Builder.CreateGEP(ScalarTy, NewPtr, Builder.getInt32(-Index)); if (InBounds) cast<GetElementPtrInst>(NewPtr)->setIsInBounds(true); @@ -2093,8 +2150,8 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr, GroupMask, UndefVec, "wide.masked.vec"); } else - NewLoad = Builder.CreateAlignedLoad(NewPtrs[Part], - Group->getAlignment(), "wide.vec"); + NewLoad = Builder.CreateAlignedLoad(VecTy, NewPtrs[Part], + Group->getAlignment(), "wide.vec"); Group->addMetadata(NewLoad); NewLoads.push_back(NewLoad); } @@ -2239,16 +2296,16 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, // If the address is consecutive but reversed, then the // wide store needs to start at the last vector element. PartPtr = cast<GetElementPtrInst>( - Builder.CreateGEP(Ptr, Builder.getInt32(-Part * VF))); + Builder.CreateGEP(ScalarDataTy, Ptr, Builder.getInt32(-Part * VF))); PartPtr->setIsInBounds(InBounds); PartPtr = cast<GetElementPtrInst>( - Builder.CreateGEP(PartPtr, Builder.getInt32(1 - VF))); + Builder.CreateGEP(ScalarDataTy, PartPtr, Builder.getInt32(1 - VF))); PartPtr->setIsInBounds(InBounds); if (isMaskRequired) // Reverse of a null all-one mask is a null mask. 
Mask[Part] = reverseVector(Mask[Part]); } else { PartPtr = cast<GetElementPtrInst>( - Builder.CreateGEP(Ptr, Builder.getInt32(Part * VF))); + Builder.CreateGEP(ScalarDataTy, Ptr, Builder.getInt32(Part * VF))); PartPtr->setIsInBounds(InBounds); } @@ -2305,7 +2362,8 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, UndefValue::get(DataTy), "wide.masked.load"); else - NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); + NewLI = + Builder.CreateAlignedLoad(DataTy, VecPtr, Alignment, "wide.load"); // Add metadata to the load, but setVectorValue to the reverse shuffle. addMetadata(NewLI, LI); @@ -2665,7 +2723,7 @@ Value *InnerLoopVectorizer::emitTransformedIndex( assert(isa<SCEVConstant>(Step) && "Expected constant step for pointer induction"); return B.CreateGEP( - nullptr, StartValue, + StartValue->getType()->getPointerElementType(), StartValue, CreateMul(Index, Exp.expandCodeFor(Step, Index->getType(), &*B.GetInsertPoint()))); } @@ -2849,26 +2907,42 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton() { BCResumeVal->addIncoming(EndValue, MiddleBlock); // Fix the scalar body counter (PHI node). - unsigned BlockIdx = OrigPhi->getBasicBlockIndex(ScalarPH); - // The old induction's phi node in the scalar body needs the truncated // value. for (BasicBlock *BB : LoopBypassBlocks) BCResumeVal->addIncoming(II.getStartValue(), BB); - OrigPhi->setIncomingValue(BlockIdx, BCResumeVal); + OrigPhi->setIncomingValueForBlock(ScalarPH, BCResumeVal); } + // We need the OrigLoop (scalar loop part) latch terminator to help + // produce correct debug info for the middle block BB instructions. + // The legality check stage guarantees that the loop will have a single + // latch. + assert(isa<BranchInst>(OrigLoop->getLoopLatch()->getTerminator()) && + "Scalar loop latch terminator isn't a branch"); + BranchInst *ScalarLatchBr = + cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator()); + // Add a check in the middle block to see if we have completed // all of the iterations in the first vector loop. // If (N - N%VF) == N, then we *don't* need to run the remainder. // If tail is to be folded, we know we don't need to run the remainder. Value *CmpN = Builder.getTrue(); - if (!Cost->foldTailByMasking()) + if (!Cost->foldTailByMasking()) { CmpN = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, Count, CountRoundDown, "cmp.n", MiddleBlock->getTerminator()); - ReplaceInstWithInst(MiddleBlock->getTerminator(), - BranchInst::Create(ExitBlock, ScalarPH, CmpN)); + + // Here we use the same DebugLoc as the scalar loop latch branch instead + // of the corresponding compare because they may have ended up with + // different line numbers and we want to avoid awkward line stepping while + // debugging. Eg. if the compare has got a line number inside the loop. + cast<Instruction>(CmpN)->setDebugLoc(ScalarLatchBr->getDebugLoc()); + } + + BranchInst *BrInst = BranchInst::Create(ExitBlock, ScalarPH, CmpN); + BrInst->setDebugLoc(ScalarLatchBr->getDebugLoc()); + ReplaceInstWithInst(MiddleBlock->getTerminator(), BrInst); // Get ready to start creating new instructions into the vectorized body. Builder.SetInsertPoint(&*VecBody->getFirstInsertionPt()); @@ -3022,45 +3096,9 @@ static void cse(BasicBlock *BB) { } } -/// Estimate the overhead of scalarizing an instruction. This is a -/// convenience wrapper for the type-based getScalarizationOverhead API. 
-static unsigned getScalarizationOverhead(Instruction *I, unsigned VF, - const TargetTransformInfo &TTI) { - if (VF == 1) - return 0; - - unsigned Cost = 0; - Type *RetTy = ToVectorTy(I->getType(), VF); - if (!RetTy->isVoidTy() && - (!isa<LoadInst>(I) || - !TTI.supportsEfficientVectorElementLoadStore())) - Cost += TTI.getScalarizationOverhead(RetTy, true, false); - - // Some targets keep addresses scalar. - if (isa<LoadInst>(I) && !TTI.prefersVectorizedAddressing()) - return Cost; - - if (CallInst *CI = dyn_cast<CallInst>(I)) { - SmallVector<const Value *, 4> Operands(CI->arg_operands()); - Cost += TTI.getOperandsScalarizationOverhead(Operands, VF); - } - else if (!isa<StoreInst>(I) || - !TTI.supportsEfficientVectorElementLoadStore()) { - SmallVector<const Value *, 4> Operands(I->operand_values()); - Cost += TTI.getOperandsScalarizationOverhead(Operands, VF); - } - - return Cost; -} - -// Estimate cost of a call instruction CI if it were vectorized with factor VF. -// Return the cost of the instruction, including scalarization overhead if it's -// needed. The flag NeedToScalarize shows if the call needs to be scalarized - -// i.e. either vector version isn't available, or is too expensive. -static unsigned getVectorCallCost(CallInst *CI, unsigned VF, - const TargetTransformInfo &TTI, - const TargetLibraryInfo *TLI, - bool &NeedToScalarize) { +unsigned LoopVectorizationCostModel::getVectorCallCost(CallInst *CI, + unsigned VF, + bool &NeedToScalarize) { Function *F = CI->getCalledFunction(); StringRef FnName = CI->getCalledFunction()->getName(); Type *ScalarRetTy = CI->getType(); @@ -3083,7 +3121,7 @@ static unsigned getVectorCallCost(CallInst *CI, unsigned VF, // Compute costs of unpacking argument values for the scalar calls and // packing the return values to a vector. - unsigned ScalarizationCost = getScalarizationOverhead(CI, VF, TTI); + unsigned ScalarizationCost = getScalarizationOverhead(CI, VF); unsigned Cost = ScalarCallCost * VF + ScalarizationCost; @@ -3102,12 +3140,8 @@ static unsigned getVectorCallCost(CallInst *CI, unsigned VF, return Cost; } -// Estimate cost of an intrinsic call instruction CI if it were vectorized with -// factor VF. Return the cost of the instruction, including scalarization -// overhead if it's needed. -static unsigned getVectorIntrinsicCost(CallInst *CI, unsigned VF, - const TargetTransformInfo &TTI, - const TargetLibraryInfo *TLI) { +unsigned LoopVectorizationCostModel::getVectorIntrinsicCost(CallInst *CI, + unsigned VF) { Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); assert(ID && "Expected intrinsic call!"); @@ -3468,7 +3502,7 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { Start->addIncoming(Incoming, BB); } - Phi->setIncomingValue(Phi->getBasicBlockIndex(LoopScalarPreHeader), Start); + Phi->setIncomingValueForBlock(LoopScalarPreHeader, Start); Phi->setName("scalar.recur"); // Finally, fix users of the recurrence outside the loop. The users will need @@ -3596,14 +3630,23 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) { // Reduce all of the unrolled parts into a single vector. Value *ReducedPartRdx = VectorLoopValueMap.getVectorValue(LoopExitInst, 0); unsigned Op = RecurrenceDescriptor::getRecurrenceBinOp(RK); - setDebugLocFromInst(Builder, ReducedPartRdx); + + // The middle block terminator has already been assigned a DebugLoc here (the + // OrigLoop's single latch terminator). 
We want the whole middle block to + // appear to execute on this line because: (a) it is all compiler generated, + // (b) these instructions are always executed after evaluating the latch + // conditional branch, and (c) other passes may add new predecessors which + // terminate on this line. This is the easiest way to ensure we don't + // accidentally cause an extra step back into the loop while debugging. + setDebugLocFromInst(Builder, LoopMiddleBlock->getTerminator()); for (unsigned Part = 1; Part < UF; ++Part) { Value *RdxPart = VectorLoopValueMap.getVectorValue(LoopExitInst, Part); if (Op != Instruction::ICmp && Op != Instruction::FCmp) // Floating point operations had to be 'fast' to enable the reduction. ReducedPartRdx = addFastMathFlag( Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxPart, - ReducedPartRdx, "bin.rdx")); + ReducedPartRdx, "bin.rdx"), + RdxDesc.getFastMathFlags()); else ReducedPartRdx = createMinMaxOp(Builder, MinMaxKind, ReducedPartRdx, RdxPart); @@ -3935,9 +3978,11 @@ void InnerLoopVectorizer::widenInstruction(Instruction &I) { // Create the new GEP. Note that this GEP may be a scalar if VF == 1, // but it should be a vector, otherwise. - auto *NewGEP = GEP->isInBounds() - ? Builder.CreateInBoundsGEP(Ptr, Indices) - : Builder.CreateGEP(Ptr, Indices); + auto *NewGEP = + GEP->isInBounds() + ? Builder.CreateInBoundsGEP(GEP->getSourceElementType(), Ptr, + Indices) + : Builder.CreateGEP(GEP->getSourceElementType(), Ptr, Indices); assert((VF == 1 || NewGEP->getType()->isVectorTy()) && "NewGEP is not a pointer vector"); VectorLoopValueMap.setVectorValue(&I, Part, NewGEP); @@ -3955,6 +4000,7 @@ void InnerLoopVectorizer::widenInstruction(Instruction &I) { case Instruction::FAdd: case Instruction::Sub: case Instruction::FSub: + case Instruction::FNeg: case Instruction::Mul: case Instruction::FMul: case Instruction::FDiv: @@ -3965,21 +4011,22 @@ void InnerLoopVectorizer::widenInstruction(Instruction &I) { case Instruction::And: case Instruction::Or: case Instruction::Xor: { - // Just widen binops. - auto *BinOp = cast<BinaryOperator>(&I); - setDebugLocFromInst(Builder, BinOp); + // Just widen unops and binops. + setDebugLocFromInst(Builder, &I); for (unsigned Part = 0; Part < UF; ++Part) { - Value *A = getOrCreateVectorValue(BinOp->getOperand(0), Part); - Value *B = getOrCreateVectorValue(BinOp->getOperand(1), Part); - Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A, B); + SmallVector<Value *, 2> Ops; + for (Value *Op : I.operands()) + Ops.push_back(getOrCreateVectorValue(Op, Part)); - if (BinaryOperator *VecOp = dyn_cast<BinaryOperator>(V)) - VecOp->copyIRFlags(BinOp); + Value *V = Builder.CreateNAryOp(I.getOpcode(), Ops); + + if (auto *VecOp = dyn_cast<Instruction>(V)) + VecOp->copyIRFlags(&I); // Use this vector value for all users of the original instruction. VectorLoopValueMap.setVectorValue(&I, Part, V); - addMetadata(V, BinOp); + addMetadata(V, &I); } break; @@ -4088,9 +4135,9 @@ void InnerLoopVectorizer::widenInstruction(Instruction &I) { // version of the instruction. // Is it beneficial to perform intrinsic call compared to lib call? 
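// Aside (not part of the patch): the decision right below compares two cost-model
// estimates that this change turns into LoopVectorizationCostModel members. As a
// hedged illustration with made-up numbers for VF = 4: getVectorCallCost() is
// roughly ScalarCallCost * VF plus scalarization overhead, e.g. 10 * 4 + 6 = 46,
// unless a vectorized library version of the callee is cheaper, while
// getVectorIntrinsicCost() might return, say, 20 when the target lowers the
// intrinsic directly. The intrinsic form is chosen only when its estimate is no
// greater than the call estimate, which is the UseVectorIntrinsic test below.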
bool NeedToScalarize; - unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize); + unsigned CallCost = Cost->getVectorCallCost(CI, VF, NeedToScalarize); bool UseVectorIntrinsic = - ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; + ID && Cost->getVectorIntrinsicCost(CI, VF) <= CallCost; assert((UseVectorIntrinsic || !NeedToScalarize) && "Instruction should be scalarized elsewhere."); @@ -4395,6 +4442,13 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened(Instruction *I, auto *Group = getInterleavedAccessGroup(I); assert(Group && "Must have a group."); + // If the instruction's allocated size doesn't equal it's type size, it + // requires padding and will be scalarized. + auto &DL = I->getModule()->getDataLayout(); + auto *ScalarTy = getMemInstValueType(I); + if (hasIrregularType(ScalarTy, DL, VF)) + return false; + // Check if masking is required. // A Group may need masking for one of two reasons: it resides in a block that // needs predication, or it was decided to use masking to deal with gaps. @@ -4987,6 +5041,8 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, if (LoopCost == 0) LoopCost = expectedCost(VF).first; + assert(LoopCost && "Non-zero loop cost expected"); + // Clamp the calculated IC to be between the 1 and the max interleave count // that the target allows. if (IC > MaxInterleaveCount) @@ -5314,15 +5370,6 @@ int LoopVectorizationCostModel::computePredInstDiscount( return true; }; - // Returns true if an operand that cannot be scalarized must be extracted - // from a vector. We will account for this scalarization overhead below. Note - // that the non-void predicated instructions are placed in their own blocks, - // and their return values are inserted into vectors. Thus, an extract would - // still be required. - auto needsExtract = [&](Instruction *I) -> bool { - return TheLoop->contains(I) && !isScalarAfterVectorization(I, VF); - }; - // Compute the expected cost discount from scalarizing the entire expression // feeding the predicated instruction. We currently only consider expressions // that are single-use instruction chains. @@ -5362,7 +5409,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( "Instruction has non-scalar type"); if (canBeScalarized(J)) Worklist.push_back(J); - else if (needsExtract(J)) + else if (needsExtract(J, VF)) ScalarCost += TTI.getScalarizationOverhead( ToVectorTy(J->getType(),VF), false, true); } @@ -5484,7 +5531,7 @@ unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, // Get the overhead of the extractelement and insertelement instructions // we might create due to scalarization. - Cost += getScalarizationOverhead(I, VF, TTI); + Cost += getScalarizationOverhead(I, VF); // If we have a predicated store, it may not be executed for each vector // lane. Scale the cost by the probability of executing the predicated @@ -5636,6 +5683,36 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { return VectorizationCostTy(C, TypeNotScalarized); } +unsigned LoopVectorizationCostModel::getScalarizationOverhead(Instruction *I, + unsigned VF) { + + if (VF == 1) + return 0; + + unsigned Cost = 0; + Type *RetTy = ToVectorTy(I->getType(), VF); + if (!RetTy->isVoidTy() && + (!isa<LoadInst>(I) || !TTI.supportsEfficientVectorElementLoadStore())) + Cost += TTI.getScalarizationOverhead(RetTy, true, false); + + // Some targets keep addresses scalar. 
+ if (isa<LoadInst>(I) && !TTI.prefersVectorizedAddressing()) + return Cost; + + // Some targets support efficient element stores. + if (isa<StoreInst>(I) && TTI.supportsEfficientVectorElementLoadStore()) + return Cost; + + // Collect operands to consider. + CallInst *CI = dyn_cast<CallInst>(I); + Instruction::op_range Ops = CI ? CI->arg_operands() : I->operands(); + + // Skip operands that do not require extraction/scalarization and do not incur + // any overhead. + return Cost + TTI.getOperandsScalarizationOverhead( + filterExtractingOperands(Ops, VF), VF); +} + void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { if (VF == 1) return; @@ -5876,7 +5953,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, // The cost of insertelement and extractelement instructions needed for // scalarization. - Cost += getScalarizationOverhead(I, VF, TTI); + Cost += getScalarizationOverhead(I, VF); // Scale the cost by the probability of executing the predicated blocks. // This assumes the predicated block for each vector lane is equally @@ -5916,6 +5993,14 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, I->getOpcode(), VectorTy, TargetTransformInfo::OK_AnyValue, Op2VK, TargetTransformInfo::OP_None, Op2VP, Operands); } + case Instruction::FNeg: { + unsigned N = isScalarAfterVectorization(I, VF) ? VF : 1; + return N * TTI.getArithmeticInstrCost( + I->getOpcode(), VectorTy, TargetTransformInfo::OK_AnyValue, + TargetTransformInfo::OK_AnyValue, + TargetTransformInfo::OP_None, TargetTransformInfo::OP_None, + I->getOperand(0)); + } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); const SCEV *CondSCEV = SE->getSCEV(SI->getCondition()); @@ -5997,16 +6082,16 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, case Instruction::Call: { bool NeedToScalarize; CallInst *CI = cast<CallInst>(I); - unsigned CallCost = getVectorCallCost(CI, VF, TTI, TLI, NeedToScalarize); + unsigned CallCost = getVectorCallCost(CI, VF, NeedToScalarize); if (getVectorIntrinsicIDForCall(CI, TLI)) - return std::min(CallCost, getVectorIntrinsicCost(CI, VF, TTI, TLI)); + return std::min(CallCost, getVectorIntrinsicCost(CI, VF)); return CallCost; } default: // The cost of executing VF copies of the scalar instruction. This opcode // is unknown. Assume that it is the same as 'mul'. return VF * TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy) + - getScalarizationOverhead(I, VF, TTI); + getScalarizationOverhead(I, VF); } // end of switch. } @@ -6027,10 +6112,13 @@ INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) INITIALIZE_PASS_END(LoopVectorize, LV_NAME, lv_name, false, false) namespace llvm { +Pass *createLoopVectorizePass() { return new LoopVectorize(); } + Pass *createLoopVectorizePass(bool InterleaveOnlyWhenForced, bool VectorizeOnlyWhenForced) { return new LoopVectorize(InterleaveOnlyWhenForced, VectorizeOnlyWhenForced); @@ -6066,50 +6154,65 @@ void LoopVectorizationCostModel::collectValuesToIgnore() { } } +// TODO: we could return a pair of values that specify the max VF and +// min VF, to be used in `buildVPlans(MinVF, MaxVF)` instead of +// `buildVPlans(VF, VF)`. 
We cannot do it because VPLAN at the moment +// doesn't have a cost model that can choose which plan to execute if +// more than one is generated. +static unsigned determineVPlanVF(const unsigned WidestVectorRegBits, + LoopVectorizationCostModel &CM) { + unsigned WidestType; + std::tie(std::ignore, WidestType) = CM.getSmallestAndWidestTypes(); + return WidestVectorRegBits / WidestType; +} + VectorizationFactor LoopVectorizationPlanner::planInVPlanNativePath(bool OptForSize, unsigned UserVF) { - // Width 1 means no vectorization, cost 0 means uncomputed cost. - const VectorizationFactor NoVectorization = {1U, 0U}; - + unsigned VF = UserVF; // Outer loop handling: They may require CFG and instruction level // transformations before even evaluating whether vectorization is profitable. // Since we cannot modify the incoming IR, we need to build VPlan upfront in // the vectorization pipeline. if (!OrigLoop->empty()) { - // TODO: If UserVF is not provided, we set UserVF to 4 for stress testing. - // This won't be necessary when UserVF is not required in the VPlan-native - // path. - if (VPlanBuildStressTest && !UserVF) - UserVF = 4; - + // If the user doesn't provide a vectorization factor, determine a + // reasonable one. + if (!UserVF) { + VF = determineVPlanVF(TTI->getRegisterBitWidth(true /* Vector*/), CM); + LLVM_DEBUG(dbgs() << "LV: VPlan computed VF " << VF << ".\n"); + + // Make sure we have a VF > 1 for stress testing. + if (VPlanBuildStressTest && VF < 2) { + LLVM_DEBUG(dbgs() << "LV: VPlan stress testing: " + << "overriding computed VF.\n"); + VF = 4; + } + } assert(EnableVPlanNativePath && "VPlan-native path is not enabled."); - assert(UserVF && "Expected UserVF for outer loop vectorization."); - assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); - LLVM_DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); - buildVPlans(UserVF, UserVF); + assert(isPowerOf2_32(VF) && "VF needs to be a power of two"); + LLVM_DEBUG(dbgs() << "LV: Using " << (UserVF ? "user " : "") << "VF " << VF + << " to build VPlans.\n"); + buildVPlans(VF, VF); // For VPlan build stress testing, we bail out after VPlan construction. if (VPlanBuildStressTest) - return NoVectorization; + return VectorizationFactor::Disabled(); - return {UserVF, 0}; + return {VF, 0}; } LLVM_DEBUG( dbgs() << "LV: Not vectorizing. Inner loops aren't supported in the " "VPlan-native path.\n"); - return NoVectorization; + return VectorizationFactor::Disabled(); } -VectorizationFactor -LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { +Optional<VectorizationFactor> LoopVectorizationPlanner::plan(bool OptForSize, + unsigned UserVF) { assert(OrigLoop->empty() && "Inner loop expected."); - // Width 1 means no vectorization, cost 0 means uncomputed cost. - const VectorizationFactor NoVectorization = {1U, 0U}; Optional<unsigned> MaybeMaxVF = CM.computeMaxVF(OptForSize); - if (!MaybeMaxVF.hasValue()) // Cases considered too costly to vectorize. - return NoVectorization; + if (!MaybeMaxVF) // Cases that should not to be vectorized nor interleaved. + return None; // Invalidate interleave groups if all blocks of loop will be predicated. 
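As a small aside that is not part of the patch: determineVPlanVF above simply divides the widest vector register width reported by TTI by the widest scalar type the cost model found in the loop. A self-contained illustration with hypothetical numbers:

  // Hypothetical values, for illustration only.
  constexpr unsigned WidestVectorRegBits = 256; // e.g. a 256-bit vector register file
  constexpr unsigned WidestTypeBits = 32;       // widest scalar type used in the loop
  constexpr unsigned VPlanVF = WidestVectorRegBits / WidestTypeBits;
  static_assert(VPlanVF == 8, "eight 32-bit lanes fit in a 256-bit register");

The stress-testing override to VF = 4 only applies when this computation yields a VF below 2. The diff continues below inside LoopVectorizationPlanner::plan().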
if (CM.blockNeedsPredication(OrigLoop->getHeader()) && @@ -6129,7 +6232,7 @@ LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { CM.selectUserVectorizationFactor(UserVF); buildVPlansWithVPRecipes(UserVF, UserVF); LLVM_DEBUG(printPlans(dbgs())); - return {UserVF, 0}; + return {{UserVF, 0}}; } unsigned MaxVF = MaybeMaxVF.getValue(); @@ -6148,7 +6251,7 @@ LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { buildVPlansWithVPRecipes(1, MaxVF); LLVM_DEBUG(printPlans(dbgs())); if (MaxVF == 1) - return NoVectorization; + return VectorizationFactor::Disabled(); // Select the optimal vectorization factor. return CM.selectVectorizationFactor(MaxVF); @@ -6527,6 +6630,7 @@ bool VPRecipeBuilder::tryToWiden(Instruction *I, VPBasicBlock *VPBB, case Instruction::FCmp: case Instruction::FDiv: case Instruction::FMul: + case Instruction::FNeg: case Instruction::FPExt: case Instruction::FPToSI: case Instruction::FPToUI: @@ -6582,9 +6686,9 @@ bool VPRecipeBuilder::tryToWiden(Instruction *I, VPBasicBlock *VPBB, // version of the instruction. // Is it beneficial to perform intrinsic call compared to lib call? bool NeedToScalarize; - unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize); + unsigned CallCost = CM.getVectorCallCost(CI, VF, NeedToScalarize); bool UseVectorIntrinsic = - ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; + ID && CM.getVectorIntrinsicCost(CI, VF) <= CallCost; return UseVectorIntrinsic || !NeedToScalarize; } if (isa<LoadInst>(I) || isa<StoreInst>(I)) { @@ -6756,8 +6860,7 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(unsigned MinVF, } } -LoopVectorizationPlanner::VPlanPtr -LoopVectorizationPlanner::buildVPlanWithVPRecipes( +VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( VFRange &Range, SmallPtrSetImpl<Value *> &NeedDef, SmallPtrSetImpl<Instruction *> &DeadInstructions) { // Hold a mapping from predicated instructions to their recipes, in order to @@ -6772,7 +6875,7 @@ LoopVectorizationPlanner::buildVPlanWithVPRecipes( VPBasicBlock *VPBB = new VPBasicBlock("Pre-Entry"); auto Plan = llvm::make_unique<VPlan>(VPBB); - VPRecipeBuilder RecipeBuilder(OrigLoop, TLI, TTI, Legal, CM, Builder); + VPRecipeBuilder RecipeBuilder(OrigLoop, TLI, Legal, CM, Builder); // Represent values that will have defs inside VPlan. for (Value *V : NeedDef) Plan->addVPValue(V); @@ -6881,8 +6984,7 @@ LoopVectorizationPlanner::buildVPlanWithVPRecipes( return Plan; } -LoopVectorizationPlanner::VPlanPtr -LoopVectorizationPlanner::buildVPlan(VFRange &Range) { +VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { // Outer loop handling: They may require CFG and instruction level // transformations before even evaluating whether vectorization is profitable. // Since we cannot modify the incoming IR, we need to build VPlan upfront in @@ -6897,13 +6999,22 @@ LoopVectorizationPlanner::buildVPlan(VFRange &Range) { VPlanHCFGBuilder HCFGBuilder(OrigLoop, LI, *Plan); HCFGBuilder.buildHierarchicalCFG(); + for (unsigned VF = Range.Start; VF < Range.End; VF *= 2) + Plan->addVF(VF); + + if (EnableVPlanPredication) { + VPlanPredicator VPP(*Plan); + VPP.predicate(); + + // Avoid running transformation to recipes until masked code generation in + // VPlan-native path is in place. 
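(The early "return Plan;" that the comment above explains comes immediately after this aside, which is not part of the patch.) A recurring theme in these hunks is the new three-way result of LoopVectorizationPlanner::plan(): None when neither vectorization nor interleaving should happen, VectorizationFactor::Disabled() when only vectorization is off, and a real {Width, Cost} otherwise. The sketch below mirrors that convention with std::optional standing in for llvm::Optional; the names and the cost value are hypothetical.

  #include <cassert>
  #include <optional>

  // Width 1 means "do not vectorize", cost 0 means "cost not computed".
  struct VF {
    unsigned Width;
    unsigned Cost;
    static VF Disabled() { return {1, 0}; }
    bool operator==(const VF &Other) const {
      return Width == Other.Width && Cost == Other.Cost;
    }
  };

  // Three outcomes: bail out entirely, keep interleaving only, or vectorize.
  std::optional<VF> planSketch(bool TooCostlyUpFront, unsigned MaxWidth) {
    if (TooCostlyUpFront)
      return std::nullopt;            // caller skips vectorization and interleaving
    if (MaxWidth == 1)
      return VF::Disabled();          // no vectorization, interleave count still chosen
    return VF{MaxWidth, /*Cost=*/42}; // hypothetical cost
  }

  int main() {
    assert(!planSketch(true, 8));
    assert(*planSketch(false, 1) == VF::Disabled());
    assert(planSketch(false, 8)->Width == 8);
  }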
+ return Plan; + } + SmallPtrSet<Instruction *, 1> DeadInstructions; VPlanHCFGTransforms::VPInstructionsToVPRecipes( Plan, Legal->getInductionVars(), DeadInstructions); - for (unsigned VF = Range.Start; VF < Range.End; VF *= 2) - Plan->addVF(VF); - return Plan; } @@ -7096,7 +7207,8 @@ static bool processLoopInVPlanNativePath( Loop *L, PredicatedScalarEvolution &PSE, LoopInfo *LI, DominatorTree *DT, LoopVectorizationLegality *LVL, TargetTransformInfo *TTI, TargetLibraryInfo *TLI, DemandedBits *DB, AssumptionCache *AC, - OptimizationRemarkEmitter *ORE, LoopVectorizeHints &Hints) { + OptimizationRemarkEmitter *ORE, BlockFrequencyInfo *BFI, + ProfileSummaryInfo *PSI, LoopVectorizeHints &Hints) { assert(EnableVPlanNativePath && "VPlan-native path is disabled."); Function *F = L->getHeader()->getParent(); @@ -7109,24 +7221,28 @@ static bool processLoopInVPlanNativePath( LoopVectorizationPlanner LVP(L, LI, TLI, TTI, LVL, CM); // Get user vectorization factor. - unsigned UserVF = Hints.getWidth(); + const unsigned UserVF = Hints.getWidth(); - // Check the function attributes to find out if this function should be - // optimized for size. + // Check the function attributes and profiles to find out if this function + // should be optimized for size. bool OptForSize = - Hints.getForce() != LoopVectorizeHints::FK_Enabled && F->optForSize(); + Hints.getForce() != LoopVectorizeHints::FK_Enabled && + (F->hasOptSize() || + llvm::shouldOptimizeForSize(L->getHeader(), PSI, BFI)); // Plan how to best vectorize, return the best VF and its cost. - VectorizationFactor VF = LVP.planInVPlanNativePath(OptForSize, UserVF); + const VectorizationFactor VF = LVP.planInVPlanNativePath(OptForSize, UserVF); // If we are stress testing VPlan builds, do not attempt to generate vector - // code. - if (VPlanBuildStressTest) + // code. Masked vector code generation support will follow soon. + // Also, do not attempt to vectorize if no vector code will be produced. + if (VPlanBuildStressTest || EnableVPlanPredication || + VectorizationFactor::Disabled() == VF) return false; LVP.setBestPlan(VF.Width, 1); - InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, UserVF, 1, LVL, + InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width, 1, LVL, &CM); LLVM_DEBUG(dbgs() << "Vectorizing outer loop in \"" << L->getHeader()->getParent()->getName() << "\"\n"); @@ -7184,7 +7300,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Check if it is legal to vectorize the loop. LoopVectorizationRequirements Requirements(*ORE); - LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, GetLAA, LI, ORE, + LoopVectorizationLegality LVL(L, PSE, DT, TTI, TLI, AA, F, GetLAA, LI, ORE, &Requirements, &Hints, DB, AC); if (!LVL.canVectorize(EnableVPlanNativePath)) { LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); @@ -7192,10 +7308,12 @@ bool LoopVectorizePass::processLoop(Loop *L) { return false; } - // Check the function attributes to find out if this function should be - // optimized for size. + // Check the function attributes and profiles to find out if this function + // should be optimized for size. bool OptForSize = - Hints.getForce() != LoopVectorizeHints::FK_Enabled && F->optForSize(); + Hints.getForce() != LoopVectorizeHints::FK_Enabled && + (F->hasOptSize() || + llvm::shouldOptimizeForSize(L->getHeader(), PSI, BFI)); // Entrance to the VPlan-native vectorization path. Outer loops are processed // here. 
They may require CFG and instruction level transformations before @@ -7204,7 +7322,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // pipeline. if (!L->empty()) return processLoopInVPlanNativePath(L, PSE, LI, DT, &LVL, TTI, TLI, DB, AC, - ORE, Hints); + ORE, BFI, PSI, Hints); assert(L->empty() && "Inner loop expected."); // Check the loop for a trip count threshold: vectorize loops with a tiny trip @@ -7304,14 +7422,18 @@ bool LoopVectorizePass::processLoop(Loop *L) { unsigned UserVF = Hints.getWidth(); // Plan how to best vectorize, return the best VF and its cost. - VectorizationFactor VF = LVP.plan(OptForSize, UserVF); + Optional<VectorizationFactor> MaybeVF = LVP.plan(OptForSize, UserVF); - // Select the interleave count. - unsigned IC = CM.selectInterleaveCount(OptForSize, VF.Width, VF.Cost); - - // Get user interleave count. + VectorizationFactor VF = VectorizationFactor::Disabled(); + unsigned IC = 1; unsigned UserIC = Hints.getInterleave(); + if (MaybeVF) { + VF = *MaybeVF; + // Select the interleave count. + IC = CM.selectInterleaveCount(OptForSize, VF.Width, VF.Cost); + } + // Identify the diagnostic messages that should be produced. std::pair<StringRef, std::string> VecDiagMsg, IntDiagMsg; bool VectorizeLoop = true, InterleaveLoop = true; @@ -7330,7 +7452,16 @@ bool LoopVectorizePass::processLoop(Loop *L) { VectorizeLoop = false; } - if (IC == 1 && UserIC <= 1) { + if (!MaybeVF && UserIC > 1) { + // Tell the user interleaving was avoided up-front, despite being explicitly + // requested. + LLVM_DEBUG(dbgs() << "LV: Ignoring UserIC, because vectorization and " + "interleaving should be avoided up front\n"); + IntDiagMsg = std::make_pair( + "InterleavingAvoided", + "Ignoring UserIC, because interleaving was avoided up front"); + InterleaveLoop = false; + } else if (IC == 1 && UserIC <= 1) { // Tell the user interleaving is not beneficial. LLVM_DEBUG(dbgs() << "LV: Interleaving is not beneficial.\n"); IntDiagMsg = std::make_pair( @@ -7457,7 +7588,7 @@ bool LoopVectorizePass::runImpl( DominatorTree &DT_, BlockFrequencyInfo &BFI_, TargetLibraryInfo *TLI_, DemandedBits &DB_, AliasAnalysis &AA_, AssumptionCache &AC_, std::function<const LoopAccessInfo &(Loop &)> &GetLAA_, - OptimizationRemarkEmitter &ORE_) { + OptimizationRemarkEmitter &ORE_, ProfileSummaryInfo *PSI_) { SE = &SE_; LI = &LI_; TTI = &TTI_; @@ -7469,6 +7600,7 @@ bool LoopVectorizePass::runImpl( GetLAA = &GetLAA_; DB = &DB_; ORE = &ORE_; + PSI = PSI_; // Don't attempt if // 1. the target claims to have no vector registers, and @@ -7488,7 +7620,8 @@ bool LoopVectorizePass::runImpl( // will simplify all loops, regardless of whether anything end up being // vectorized. for (auto &L : *LI) - Changed |= simplifyLoop(L, DT, LI, SE, AC, false /* PreserveLCSSA */); + Changed |= + simplifyLoop(L, DT, LI, SE, AC, nullptr, false /* PreserveLCSSA */); // Build up a worklist of inner-loops to vectorize. This is necessary as // the act of vectorizing or partially unrolling a loop creates new loops @@ -7527,15 +7660,22 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &DB = AM.getResult<DemandedBitsAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + MemorySSA *MSSA = EnableMSSALoopDependency + ? 
&AM.getResult<MemorySSAAnalysis>(F).getMSSA() + : nullptr; auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); std::function<const LoopAccessInfo &(Loop &)> GetLAA = [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI, nullptr}; + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI, MSSA}; return LAM.getResult<LoopAccessAnalysis>(L, AR); }; + const ModuleAnalysisManager &MAM = + AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); + ProfileSummaryInfo *PSI = + MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); bool Changed = - runImpl(F, SE, LI, TTI, DT, BFI, &TLI, DB, AA, AC, GetLAA, ORE); + runImpl(F, SE, LI, TTI, DT, BFI, &TLI, DB, AA, AC, GetLAA, ORE, PSI); if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index 2e856a7e6802..27a86c0bca91 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -1,9 +1,8 @@ //===- SLPVectorizer.cpp - A bottom up SLP Vectorizer ---------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -106,6 +105,10 @@ using namespace slpvectorizer; STATISTIC(NumVectorInstructions, "Number of vector instructions generated"); +cl::opt<bool> + llvm::RunSLPVectorization("vectorize-slp", cl::init(false), cl::Hidden, + cl::desc("Run the SLP vectorization passes")); + static cl::opt<int> SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden, cl::desc("Only vectorize if you gain more than this " @@ -207,6 +210,13 @@ static bool isSplat(ArrayRef<Value *> VL) { return true; } +/// \returns True if \p I is commutative, handles CmpInst as well as Instruction. +static bool isCommutative(Instruction *I) { + if (auto *IC = dyn_cast<CmpInst>(I)) + return IC->isCommutative(); + return I->isCommutative(); +} + /// Checks if the vector of instructions can be represented as a shuffle, like: /// %x0 = extractelement <4 x i8> %x, i32 0 /// %x3 = extractelement <4 x i8> %x, i32 3 @@ -438,8 +448,9 @@ static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, case Instruction::Call: { CallInst *CI = cast<CallInst>(UserInst); Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); - if (hasVectorInstrinsicScalarOpd(ID, 1)) { - return (CI->getArgOperand(1) == Scalar); + for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { + if (hasVectorInstrinsicScalarOpd(ID, i)) + return (CI->getArgOperand(i) == Scalar); } LLVM_FALLTHROUGH; } @@ -474,6 +485,8 @@ namespace slpvectorizer { /// Bottom Up SLP Vectorizer. class BoUpSLP { + struct TreeEntry; + public: using ValueList = SmallVector<Value *, 8>; using InstrList = SmallVector<Instruction *, 16>; @@ -517,7 +530,7 @@ public: /// \returns the cost incurred by unwanted spills and fills, caused by /// holding live values over call sites. - int getSpillCost(); + int getSpillCost() const; /// \returns the vectorization cost of the subtree that starts at \p VL. /// A negative number means that this is profitable. 
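A brief aside, not part of the patch, on the new isCommutative() helper above: the generic Instruction::isCommutative() is keyed off the opcode and does not look at compare predicates, so compares always look non-commutative to it, while CmpInst can answer per predicate (roughly, eq/ne-style compares commute and relational ones do not). Dispatching to CmpInst::isCommutative therefore lets the SLP operand reordering treat "icmp eq a, b" the same as "icmp eq b, a". A tiny C++ analogy:

  #include <cassert>

  int main() {
    int a = 1, b = 2;
    assert((a == b) == (b == a)); // equality-style compares are commutative
    assert((a < b) != (b < a));   // relational compares are not (when a != b)
  }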
@@ -576,7 +589,7 @@ public: /// the stored value. Otherwise, the size is the width of the largest loaded /// value reaching V. This method is used by the vectorizer to calculate /// vectorization factors. - unsigned getVectorElementSize(Value *V); + unsigned getVectorElementSize(Value *V) const; /// Compute the minimum type sizes required to represent the entries in a /// vectorizable tree. @@ -599,13 +612,512 @@ public: /// \returns True if the VectorizableTree is both tiny and not fully /// vectorizable. We do not vectorize such trees. - bool isTreeTinyAndNotFullyVectorizable(); + bool isTreeTinyAndNotFullyVectorizable() const; OptimizationRemarkEmitter *getORE() { return ORE; } -private: - struct TreeEntry; + /// This structure holds any data we need about the edges being traversed + /// during buildTree_rec(). We keep track of: + /// (i) the user TreeEntry index, and + /// (ii) the index of the edge. + struct EdgeInfo { + EdgeInfo() = default; + EdgeInfo(TreeEntry *UserTE, unsigned EdgeIdx) + : UserTE(UserTE), EdgeIdx(EdgeIdx) {} + /// The user TreeEntry. + TreeEntry *UserTE = nullptr; + /// The operand index of the use. + unsigned EdgeIdx = UINT_MAX; +#ifndef NDEBUG + friend inline raw_ostream &operator<<(raw_ostream &OS, + const BoUpSLP::EdgeInfo &EI) { + EI.dump(OS); + return OS; + } + /// Debug print. + void dump(raw_ostream &OS) const { + OS << "{User:" << (UserTE ? std::to_string(UserTE->Idx) : "null") + << " EdgeIdx:" << EdgeIdx << "}"; + } + LLVM_DUMP_METHOD void dump() const { dump(dbgs()); } +#endif + }; + + /// A helper data structure to hold the operands of a vector of instructions. + /// This supports a fixed vector length for all operand vectors. + class VLOperands { + /// For each operand we need (i) the value, and (ii) the opcode that it + /// would be attached to if the expression was in a left-linearized form. + /// This is required to avoid illegal operand reordering. + /// For example: + /// \verbatim + /// 0 Op1 + /// |/ + /// Op1 Op2 Linearized + Op2 + /// \ / ----------> |/ + /// - - + /// + /// Op1 - Op2 (0 + Op1) - Op2 + /// \endverbatim + /// + /// Value Op1 is attached to a '+' operation, and Op2 to a '-'. + /// + /// Another way to think of this is to track all the operations across the + /// path from the operand all the way to the root of the tree and to + /// calculate the operation that corresponds to this path. For example, the + /// path from Op2 to the root crosses the RHS of the '-', therefore the + /// corresponding operation is a '-' (which matches the one in the + /// linearized tree, as shown above). + /// + /// For lack of a better term, we refer to this operation as Accumulated + /// Path Operation (APO). + struct OperandData { + OperandData() = default; + OperandData(Value *V, bool APO, bool IsUsed) + : V(V), APO(APO), IsUsed(IsUsed) {} + /// The operand value. + Value *V = nullptr; + /// TreeEntries only allow a single opcode, or an alternate sequence of + /// them (e.g, +, -). Therefore, we can safely use a boolean value for the + /// APO. It is set to 'true' if 'V' is attached to an inverse operation + /// in the left-linearized form (e.g., Sub/Div), and 'false' otherwise + /// (e.g., Add/Mul) + bool APO = false; + /// Helper data for the reordering function. + bool IsUsed = false; + }; + + /// During operand reordering, we are trying to select the operand at lane + /// that matches best with the operand at the neighboring lane. Our + /// selection is based on the type of value we are looking for. 
For example, + /// if the neighboring lane has a load, we need to look for a load that is + /// accessing a consecutive address. These strategies are summarized in the + /// 'ReorderingMode' enumerator. + enum class ReorderingMode { + Load, ///< Matching loads to consecutive memory addresses + Opcode, ///< Matching instructions based on opcode (same or alternate) + Constant, ///< Matching constants + Splat, ///< Matching the same instruction multiple times (broadcast) + Failed, ///< We failed to create a vectorizable group + }; + + using OperandDataVec = SmallVector<OperandData, 2>; + + /// A vector of operand vectors. + SmallVector<OperandDataVec, 4> OpsVec; + + const DataLayout &DL; + ScalarEvolution &SE; + + /// \returns the operand data at \p OpIdx and \p Lane. + OperandData &getData(unsigned OpIdx, unsigned Lane) { + return OpsVec[OpIdx][Lane]; + } + + /// \returns the operand data at \p OpIdx and \p Lane. Const version. + const OperandData &getData(unsigned OpIdx, unsigned Lane) const { + return OpsVec[OpIdx][Lane]; + } + + /// Clears the used flag for all entries. + void clearUsed() { + for (unsigned OpIdx = 0, NumOperands = getNumOperands(); + OpIdx != NumOperands; ++OpIdx) + for (unsigned Lane = 0, NumLanes = getNumLanes(); Lane != NumLanes; + ++Lane) + OpsVec[OpIdx][Lane].IsUsed = false; + } + + /// Swap the operand at \p OpIdx1 with that one at \p OpIdx2. + void swap(unsigned OpIdx1, unsigned OpIdx2, unsigned Lane) { + std::swap(OpsVec[OpIdx1][Lane], OpsVec[OpIdx2][Lane]); + } + + // Search all operands in Ops[*][Lane] for the one that matches best + // Ops[OpIdx][LastLane] and return its opreand index. + // If no good match can be found, return None. + Optional<unsigned> + getBestOperand(unsigned OpIdx, int Lane, int LastLane, + ArrayRef<ReorderingMode> ReorderingModes) { + unsigned NumOperands = getNumOperands(); + + // The operand of the previous lane at OpIdx. + Value *OpLastLane = getData(OpIdx, LastLane).V; + + // Our strategy mode for OpIdx. + ReorderingMode RMode = ReorderingModes[OpIdx]; + + // The linearized opcode of the operand at OpIdx, Lane. + bool OpIdxAPO = getData(OpIdx, Lane).APO; + + const unsigned BestScore = 2; + const unsigned GoodScore = 1; + + // The best operand index and its score. + // Sometimes we have more than one option (e.g., Opcode and Undefs), so we + // are using the score to differentiate between the two. + struct BestOpData { + Optional<unsigned> Idx = None; + unsigned Score = 0; + } BestOp; + + // Iterate through all unused operands and look for the best. + for (unsigned Idx = 0; Idx != NumOperands; ++Idx) { + // Get the operand at Idx and Lane. + OperandData &OpData = getData(Idx, Lane); + Value *Op = OpData.V; + bool OpAPO = OpData.APO; + + // Skip already selected operands. + if (OpData.IsUsed) + continue; + + // Skip if we are trying to move the operand to a position with a + // different opcode in the linearized tree form. This would break the + // semantics. + if (OpAPO != OpIdxAPO) + continue; + + // Look for an operand that matches the current mode. + switch (RMode) { + case ReorderingMode::Load: + if (isa<LoadInst>(Op)) { + // Figure out which is left and right, so that we can check for + // consecutive loads + bool LeftToRight = Lane > LastLane; + Value *OpLeft = (LeftToRight) ? OpLastLane : Op; + Value *OpRight = (LeftToRight) ? 
Op : OpLastLane; + if (isConsecutiveAccess(cast<LoadInst>(OpLeft), + cast<LoadInst>(OpRight), DL, SE)) + BestOp.Idx = Idx; + } + break; + case ReorderingMode::Opcode: + // We accept both Instructions and Undefs, but with different scores. + if ((isa<Instruction>(Op) && isa<Instruction>(OpLastLane) && + cast<Instruction>(Op)->getOpcode() == + cast<Instruction>(OpLastLane)->getOpcode()) || + (isa<UndefValue>(OpLastLane) && isa<Instruction>(Op)) || + isa<UndefValue>(Op)) { + // An instruction has a higher score than an undef. + unsigned Score = (isa<UndefValue>(Op)) ? GoodScore : BestScore; + if (Score > BestOp.Score) { + BestOp.Idx = Idx; + BestOp.Score = Score; + } + } + break; + case ReorderingMode::Constant: + if (isa<Constant>(Op)) { + unsigned Score = (isa<UndefValue>(Op)) ? GoodScore : BestScore; + if (Score > BestOp.Score) { + BestOp.Idx = Idx; + BestOp.Score = Score; + } + } + break; + case ReorderingMode::Splat: + if (Op == OpLastLane) + BestOp.Idx = Idx; + break; + case ReorderingMode::Failed: + return None; + } + } + + if (BestOp.Idx) { + getData(BestOp.Idx.getValue(), Lane).IsUsed = true; + return BestOp.Idx; + } + // If we could not find a good match return None. + return None; + } + + /// Helper for reorderOperandVecs. \Returns the lane that we should start + /// reordering from. This is the one which has the least number of operands + /// that can freely move about. + unsigned getBestLaneToStartReordering() const { + unsigned BestLane = 0; + unsigned Min = UINT_MAX; + for (unsigned Lane = 0, NumLanes = getNumLanes(); Lane != NumLanes; + ++Lane) { + unsigned NumFreeOps = getMaxNumOperandsThatCanBeReordered(Lane); + if (NumFreeOps < Min) { + Min = NumFreeOps; + BestLane = Lane; + } + } + return BestLane; + } + + /// \Returns the maximum number of operands that are allowed to be reordered + /// for \p Lane. This is used as a heuristic for selecting the first lane to + /// start operand reordering. + unsigned getMaxNumOperandsThatCanBeReordered(unsigned Lane) const { + unsigned CntTrue = 0; + unsigned NumOperands = getNumOperands(); + // Operands with the same APO can be reordered. We therefore need to count + // how many of them we have for each APO, like this: Cnt[APO] = x. + // Since we only have two APOs, namely true and false, we can avoid using + // a map. Instead we can simply count the number of operands that + // correspond to one of them (in this case the 'true' APO), and calculate + // the other by subtracting it from the total number of operands. + for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) + if (getData(OpIdx, Lane).APO) + ++CntTrue; + unsigned CntFalse = NumOperands - CntTrue; + return std::max(CntTrue, CntFalse); + } + + /// Go through the instructions in VL and append their operands. + void appendOperandsOfVL(ArrayRef<Value *> VL) { + assert(!VL.empty() && "Bad VL"); + assert((empty() || VL.size() == getNumLanes()) && + "Expected same number of lanes"); + assert(isa<Instruction>(VL[0]) && "Expected instruction"); + unsigned NumOperands = cast<Instruction>(VL[0])->getNumOperands(); + OpsVec.resize(NumOperands); + unsigned NumLanes = VL.size(); + for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) { + OpsVec[OpIdx].resize(NumLanes); + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + assert(isa<Instruction>(VL[Lane]) && "Expected instruction"); + // Our tree has just 3 nodes: the root and two operands. + // It is therefore trivial to get the APO. 
We only need to check the + // opcode of VL[Lane] and whether the operand at OpIdx is the LHS or + // RHS operand. The LHS operand of both add and sub is never attached + // to an inversese operation in the linearized form, therefore its APO + // is false. The RHS is true only if VL[Lane] is an inverse operation. + + // Since operand reordering is performed on groups of commutative + // operations or alternating sequences (e.g., +, -), we can safely + // tell the inverse operations by checking commutativity. + bool IsInverseOperation = !isCommutative(cast<Instruction>(VL[Lane])); + bool APO = (OpIdx == 0) ? false : IsInverseOperation; + OpsVec[OpIdx][Lane] = {cast<Instruction>(VL[Lane])->getOperand(OpIdx), + APO, false}; + } + } + } + + /// \returns the number of operands. + unsigned getNumOperands() const { return OpsVec.size(); } + + /// \returns the number of lanes. + unsigned getNumLanes() const { return OpsVec[0].size(); } + + /// \returns the operand value at \p OpIdx and \p Lane. + Value *getValue(unsigned OpIdx, unsigned Lane) const { + return getData(OpIdx, Lane).V; + } + /// \returns true if the data structure is empty. + bool empty() const { return OpsVec.empty(); } + + /// Clears the data. + void clear() { OpsVec.clear(); } + + /// \Returns true if there are enough operands identical to \p Op to fill + /// the whole vector. + /// Note: This modifies the 'IsUsed' flag, so a cleanUsed() must follow. + bool shouldBroadcast(Value *Op, unsigned OpIdx, unsigned Lane) { + bool OpAPO = getData(OpIdx, Lane).APO; + for (unsigned Ln = 0, Lns = getNumLanes(); Ln != Lns; ++Ln) { + if (Ln == Lane) + continue; + // This is set to true if we found a candidate for broadcast at Lane. + bool FoundCandidate = false; + for (unsigned OpI = 0, OpE = getNumOperands(); OpI != OpE; ++OpI) { + OperandData &Data = getData(OpI, Ln); + if (Data.APO != OpAPO || Data.IsUsed) + continue; + if (Data.V == Op) { + FoundCandidate = true; + Data.IsUsed = true; + break; + } + } + if (!FoundCandidate) + return false; + } + return true; + } + + public: + /// Initialize with all the operands of the instruction vector \p RootVL. + VLOperands(ArrayRef<Value *> RootVL, const DataLayout &DL, + ScalarEvolution &SE) + : DL(DL), SE(SE) { + // Append all the operands of RootVL. + appendOperandsOfVL(RootVL); + } + + /// \Returns a value vector with the operands across all lanes for the + /// opearnd at \p OpIdx. + ValueList getVL(unsigned OpIdx) const { + ValueList OpVL(OpsVec[OpIdx].size()); + assert(OpsVec[OpIdx].size() == getNumLanes() && + "Expected same num of lanes across all operands"); + for (unsigned Lane = 0, Lanes = getNumLanes(); Lane != Lanes; ++Lane) + OpVL[Lane] = OpsVec[OpIdx][Lane].V; + return OpVL; + } + + // Performs operand reordering for 2 or more operands. + // The original operands are in OrigOps[OpIdx][Lane]. + // The reordered operands are returned in 'SortedOps[OpIdx][Lane]'. + void reorder() { + unsigned NumOperands = getNumOperands(); + unsigned NumLanes = getNumLanes(); + // Each operand has its own mode. We are using this mode to help us select + // the instructions for each lane, so that they match best with the ones + // we have selected so far. + SmallVector<ReorderingMode, 2> ReorderingModes(NumOperands); + + // This is a greedy single-pass algorithm. We are going over each lane + // once and deciding on the best order right away with no back-tracking. + // However, in order to increase its effectiveness, we start with the lane + // that has operands that can move the least. 
For example, given the + // following lanes: + // Lane 0 : A[0] = B[0] + C[0] // Visited 3rd + // Lane 1 : A[1] = C[1] - B[1] // Visited 1st + // Lane 2 : A[2] = B[2] + C[2] // Visited 2nd + // Lane 3 : A[3] = C[3] - B[3] // Visited 4th + // we will start at Lane 1, since the operands of the subtraction cannot + // be reordered. Then we will visit the rest of the lanes in a circular + // fashion. That is, Lanes 2, then Lane 0, and finally Lane 3. + + // Find the first lane that we will start our search from. + unsigned FirstLane = getBestLaneToStartReordering(); + + // Initialize the modes. + for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) { + Value *OpLane0 = getValue(OpIdx, FirstLane); + // Keep track if we have instructions with all the same opcode on one + // side. + if (isa<LoadInst>(OpLane0)) + ReorderingModes[OpIdx] = ReorderingMode::Load; + else if (isa<Instruction>(OpLane0)) { + // Check if OpLane0 should be broadcast. + if (shouldBroadcast(OpLane0, OpIdx, FirstLane)) + ReorderingModes[OpIdx] = ReorderingMode::Splat; + else + ReorderingModes[OpIdx] = ReorderingMode::Opcode; + } + else if (isa<Constant>(OpLane0)) + ReorderingModes[OpIdx] = ReorderingMode::Constant; + else if (isa<Argument>(OpLane0)) + // Our best hope is a Splat. It may save some cost in some cases. + ReorderingModes[OpIdx] = ReorderingMode::Splat; + else + // NOTE: This should be unreachable. + ReorderingModes[OpIdx] = ReorderingMode::Failed; + } + + // If the initial strategy fails for any of the operand indexes, then we + // perform reordering again in a second pass. This helps avoid assigning + // high priority to the failed strategy, and should improve reordering for + // the non-failed operand indexes. + for (int Pass = 0; Pass != 2; ++Pass) { + // Skip the second pass if the first pass did not fail. + bool StrategyFailed = false; + // Mark all operand data as free to use. + clearUsed(); + // We keep the original operand order for the FirstLane, so reorder the + // rest of the lanes. We are visiting the nodes in a circular fashion, + // using FirstLane as the center point and increasing the radius + // distance. + for (unsigned Distance = 1; Distance != NumLanes; ++Distance) { + // Visit the lane on the right and then the lane on the left. + for (int Direction : {+1, -1}) { + int Lane = FirstLane + Direction * Distance; + if (Lane < 0 || Lane >= (int)NumLanes) + continue; + int LastLane = Lane - Direction; + assert(LastLane >= 0 && LastLane < (int)NumLanes && + "Out of bounds"); + // Look for a good match for each operand. + for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) { + // Search for the operand that matches SortedOps[OpIdx][Lane-1]. + Optional<unsigned> BestIdx = + getBestOperand(OpIdx, Lane, LastLane, ReorderingModes); + // By not selecting a value, we allow the operands that follow to + // select a better matching value. We will get a non-null value in + // the next run of getBestOperand(). + if (BestIdx) { + // Swap the current operand with the one returned by + // getBestOperand(). + swap(OpIdx, BestIdx.getValue(), Lane); + } else { + // We failed to find a best operand, set mode to 'Failed'. + ReorderingModes[OpIdx] = ReorderingMode::Failed; + // Enable the second pass. + StrategyFailed = true; + } + } + } + } + // Skip second pass if the strategy did not fail. 
+ if (!StrategyFailed) + break; + } + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD static StringRef getModeStr(ReorderingMode RMode) { + switch (RMode) { + case ReorderingMode::Load: + return "Load"; + case ReorderingMode::Opcode: + return "Opcode"; + case ReorderingMode::Constant: + return "Constant"; + case ReorderingMode::Splat: + return "Splat"; + case ReorderingMode::Failed: + return "Failed"; + } + llvm_unreachable("Unimplemented Reordering Type"); + } + + LLVM_DUMP_METHOD static raw_ostream &printMode(ReorderingMode RMode, + raw_ostream &OS) { + return OS << getModeStr(RMode); + } + + /// Debug print. + LLVM_DUMP_METHOD static void dumpMode(ReorderingMode RMode) { + printMode(RMode, dbgs()); + } + + friend raw_ostream &operator<<(raw_ostream &OS, ReorderingMode RMode) { + return printMode(RMode, OS); + } + + LLVM_DUMP_METHOD raw_ostream &print(raw_ostream &OS) const { + const unsigned Indent = 2; + unsigned Cnt = 0; + for (const OperandDataVec &OpDataVec : OpsVec) { + OS << "Operand " << Cnt++ << "\n"; + for (const OperandData &OpData : OpDataVec) { + OS.indent(Indent) << "{"; + if (Value *V = OpData.V) + OS << *V; + else + OS << "null"; + OS << ", APO:" << OpData.APO << "}\n"; + } + OS << "\n"; + } + return OS; + } + + /// Debug print. + LLVM_DUMP_METHOD void dump() const { print(dbgs()); } +#endif + }; + +private: /// Checks if all users of \p I are the part of the vectorization tree. bool areAllUsersVectorized(Instruction *I) const; @@ -613,7 +1125,8 @@ private: int getEntryCost(TreeEntry *E); /// This is the recursive part of buildTree. - void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth, int); + void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth, + const EdgeInfo &EI); /// \returns true if the ExtractElement/ExtractValue instructions in \p VL can /// be vectorized to use the original vector (or aggregate "bitcast" to a @@ -631,12 +1144,12 @@ private: /// \returns the scalarization cost for this type. Scalarization in this /// context means the creation of vectors from a group of scalars. - int getGatherCost(Type *Ty, const DenseSet<unsigned> &ShuffledIndices); + int getGatherCost(Type *Ty, const DenseSet<unsigned> &ShuffledIndices) const; /// \returns the scalarization cost for this list of values. Assuming that /// this subtree gets vectorized, we may need to extract the values from the /// roots. This method calculates the cost of extracting the values. - int getGatherCost(ArrayRef<Value *> VL); + int getGatherCost(ArrayRef<Value *> VL) const; /// Set the Builder insert point to one after the last instruction in /// the bundle @@ -648,22 +1161,18 @@ private: /// \returns whether the VectorizableTree is fully vectorizable and will /// be beneficial even the tree height is tiny. - bool isFullyVectorizableTinyTree(); + bool isFullyVectorizableTinyTree() const; - /// \reorder commutative operands in alt shuffle if they result in - /// vectorized code. - void reorderAltShuffleOperands(const InstructionsState &S, - ArrayRef<Value *> VL, - SmallVectorImpl<Value *> &Left, - SmallVectorImpl<Value *> &Right); - - /// \reorder commutative operands to get better probability of + /// Reorder commutative or alt operands to get better probability of /// generating vectorized code. 
- void reorderInputsAccordingToOpcode(unsigned Opcode, ArrayRef<Value *> VL, - SmallVectorImpl<Value *> &Left, - SmallVectorImpl<Value *> &Right); + static void reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, + SmallVectorImpl<Value *> &Left, + SmallVectorImpl<Value *> &Right, + const DataLayout &DL, + ScalarEvolution &SE); struct TreeEntry { - TreeEntry(std::vector<TreeEntry> &Container) : Container(Container) {} + using VecTreeTy = SmallVector<std::unique_ptr<TreeEntry>, 8>; + TreeEntry(VecTreeTy &Container) : Container(Container) {} /// \returns true if the scalars in VL are equal to this entry. bool isSame(ArrayRef<Value *> VL) const { @@ -696,20 +1205,103 @@ private: /// to be a pointer and needs to be able to initialize the child iterator. /// Thus we need a reference back to the container to translate the indices /// to entries. - std::vector<TreeEntry> &Container; + VecTreeTy &Container; /// The TreeEntry index containing the user of this entry. We can actually /// have multiple users so the data structure is not truly a tree. - SmallVector<int, 1> UserTreeIndices; + SmallVector<EdgeInfo, 1> UserTreeIndices; + + /// The index of this treeEntry in VectorizableTree. + int Idx = -1; + + private: + /// The operands of each instruction in each lane Operands[op_index][lane]. + /// Note: This helps avoid the replication of the code that performs the + /// reordering of operands during buildTree_rec() and vectorizeTree(). + SmallVector<ValueList, 2> Operands; + + public: + /// Set this bundle's \p OpIdx'th operand to \p OpVL. + void setOperand(unsigned OpIdx, ArrayRef<Value *> OpVL, + ArrayRef<unsigned> ReuseShuffleIndices) { + if (Operands.size() < OpIdx + 1) + Operands.resize(OpIdx + 1); + assert(Operands[OpIdx].size() == 0 && "Already resized?"); + Operands[OpIdx].resize(Scalars.size()); + for (unsigned Lane = 0, E = Scalars.size(); Lane != E; ++Lane) + Operands[OpIdx][Lane] = (!ReuseShuffleIndices.empty()) + ? OpVL[ReuseShuffleIndices[Lane]] + : OpVL[Lane]; + } + + /// If there is a user TreeEntry, then set its operand. + void trySetUserTEOperand(const EdgeInfo &UserTreeIdx, + ArrayRef<Value *> OpVL, + ArrayRef<unsigned> ReuseShuffleIndices) { + if (UserTreeIdx.UserTE) + UserTreeIdx.UserTE->setOperand(UserTreeIdx.EdgeIdx, OpVL, + ReuseShuffleIndices); + } + + /// \returns the \p OpIdx operand of this TreeEntry. + ValueList &getOperand(unsigned OpIdx) { + assert(OpIdx < Operands.size() && "Off bounds"); + return Operands[OpIdx]; + } + + /// \return the single \p OpIdx operand. + Value *getSingleOperand(unsigned OpIdx) const { + assert(OpIdx < Operands.size() && "Off bounds"); + assert(!Operands[OpIdx].empty() && "No operand available"); + return Operands[OpIdx][0]; + } + +#ifndef NDEBUG + /// Debug printer. 
+ LLVM_DUMP_METHOD void dump() const { + dbgs() << Idx << ".\n"; + for (unsigned OpI = 0, OpE = Operands.size(); OpI != OpE; ++OpI) { + dbgs() << "Operand " << OpI << ":\n"; + for (const Value *V : Operands[OpI]) + dbgs().indent(2) << *V << "\n"; + } + dbgs() << "Scalars: \n"; + for (Value *V : Scalars) + dbgs().indent(2) << *V << "\n"; + dbgs() << "NeedToGather: " << NeedToGather << "\n"; + dbgs() << "VectorizedValue: "; + if (VectorizedValue) + dbgs() << *VectorizedValue; + else + dbgs() << "NULL"; + dbgs() << "\n"; + dbgs() << "ReuseShuffleIndices: "; + if (ReuseShuffleIndices.empty()) + dbgs() << "Emtpy"; + else + for (unsigned Idx : ReuseShuffleIndices) + dbgs() << Idx << ", "; + dbgs() << "\n"; + dbgs() << "ReorderIndices: "; + for (unsigned Idx : ReorderIndices) + dbgs() << Idx << ", "; + dbgs() << "\n"; + dbgs() << "UserTreeIndices: "; + for (const auto &EInfo : UserTreeIndices) + dbgs() << EInfo << ", "; + dbgs() << "\n"; + } +#endif }; /// Create a new VectorizableTree entry. - void newTreeEntry(ArrayRef<Value *> VL, bool Vectorized, int &UserTreeIdx, - ArrayRef<unsigned> ReuseShuffleIndices = None, - ArrayRef<unsigned> ReorderIndices = None) { - VectorizableTree.emplace_back(VectorizableTree); - int idx = VectorizableTree.size() - 1; - TreeEntry *Last = &VectorizableTree[idx]; + TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized, + const EdgeInfo &UserTreeIdx, + ArrayRef<unsigned> ReuseShuffleIndices = None, + ArrayRef<unsigned> ReorderIndices = None) { + VectorizableTree.push_back(llvm::make_unique<TreeEntry>(VectorizableTree)); + TreeEntry *Last = VectorizableTree.back().get(); + Last->Idx = VectorizableTree.size() - 1; Last->Scalars.insert(Last->Scalars.begin(), VL.begin(), VL.end()); Last->NeedToGather = !Vectorized; Last->ReuseShuffleIndices.append(ReuseShuffleIndices.begin(), @@ -718,25 +1310,44 @@ private: if (Vectorized) { for (int i = 0, e = VL.size(); i != e; ++i) { assert(!getTreeEntry(VL[i]) && "Scalar already in tree!"); - ScalarToTreeEntry[VL[i]] = idx; + ScalarToTreeEntry[VL[i]] = Last->Idx; } } else { MustGather.insert(VL.begin(), VL.end()); } - if (UserTreeIdx >= 0) + if (UserTreeIdx.UserTE) Last->UserTreeIndices.push_back(UserTreeIdx); - UserTreeIdx = idx; + + Last->trySetUserTEOperand(UserTreeIdx, VL, ReuseShuffleIndices); + return Last; } /// -- Vectorization State -- /// Holds all of the tree entries. - std::vector<TreeEntry> VectorizableTree; + TreeEntry::VecTreeTy VectorizableTree; + +#ifndef NDEBUG + /// Debug printer. + LLVM_DUMP_METHOD void dumpVectorizableTree() const { + for (unsigned Id = 0, IdE = VectorizableTree.size(); Id != IdE; ++Id) { + VectorizableTree[Id]->dump(); + dbgs() << "\n"; + } + } +#endif TreeEntry *getTreeEntry(Value *V) { auto I = ScalarToTreeEntry.find(V); if (I != ScalarToTreeEntry.end()) - return &VectorizableTree[I->second]; + return VectorizableTree[I->second].get(); + return nullptr; + } + + const TreeEntry *getTreeEntry(Value *V) const { + auto I = ScalarToTreeEntry.find(V); + if (I != ScalarToTreeEntry.end()) + return VectorizableTree[I->second].get(); return nullptr; } @@ -1246,21 +1857,25 @@ template <> struct GraphTraits<BoUpSLP *> { /// NodeRef has to be a pointer per the GraphWriter. using NodeRef = TreeEntry *; + using ContainerTy = BoUpSLP::TreeEntry::VecTreeTy; + /// Add the VectorizableTree to the index iterator to be able to return /// TreeEntry pointers. 
struct ChildIteratorType - : public iterator_adaptor_base<ChildIteratorType, - SmallVector<int, 1>::iterator> { - std::vector<TreeEntry> &VectorizableTree; + : public iterator_adaptor_base< + ChildIteratorType, SmallVector<BoUpSLP::EdgeInfo, 1>::iterator> { + ContainerTy &VectorizableTree; - ChildIteratorType(SmallVector<int, 1>::iterator W, - std::vector<TreeEntry> &VT) + ChildIteratorType(SmallVector<BoUpSLP::EdgeInfo, 1>::iterator W, + ContainerTy &VT) : ChildIteratorType::iterator_adaptor_base(W), VectorizableTree(VT) {} - NodeRef operator*() { return &VectorizableTree[*I]; } + NodeRef operator*() { return I->UserTE; } }; - static NodeRef getEntryNode(BoUpSLP &R) { return &R.VectorizableTree[0]; } + static NodeRef getEntryNode(BoUpSLP &R) { + return R.VectorizableTree[0].get(); + } static ChildIteratorType child_begin(NodeRef N) { return {N->UserTreeIndices.begin(), N->Container}; @@ -1272,7 +1887,19 @@ template <> struct GraphTraits<BoUpSLP *> { /// For the node iterator we just need to turn the TreeEntry iterator into a /// TreeEntry* iterator so that it dereferences to NodeRef. - using nodes_iterator = pointer_iterator<std::vector<TreeEntry>::iterator>; + class nodes_iterator { + using ItTy = ContainerTy::iterator; + ItTy It; + + public: + nodes_iterator(const ItTy &It2) : It(It2) {} + NodeRef operator*() { return It->get(); } + nodes_iterator operator++() { + ++It; + return *this; + } + bool operator!=(const nodes_iterator &N2) const { return N2.It != It; } + }; static nodes_iterator nodes_begin(BoUpSLP *R) { return nodes_iterator(R->VectorizableTree.begin()); @@ -1331,11 +1958,11 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, UserIgnoreList = UserIgnoreLst; if (!allSameType(Roots)) return; - buildTree_rec(Roots, 0, -1); + buildTree_rec(Roots, 0, EdgeInfo()); // Collect the values that we need to extract from the tree. - for (TreeEntry &EIdx : VectorizableTree) { - TreeEntry *Entry = &EIdx; + for (auto &TEPtr : VectorizableTree) { + TreeEntry *Entry = TEPtr.get(); // No need to handle users of gathered values. if (Entry->NeedToGather) @@ -1393,7 +2020,7 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, } void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, - int UserTreeIdx) { + const EdgeInfo &UserTreeIdx) { assert((allConstant(VL) || allSameType(VL)) && "Invalid types!"); InstructionsState S = getSameOpcode(VL); @@ -1450,6 +2077,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, E->UserTreeIndices.push_back(UserTreeIdx); LLVM_DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.OpValue << ".\n"); + E->trySetUserTEOperand(UserTreeIdx, VL, None); return; } @@ -1468,8 +2096,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // If any of the scalars is marked as a value that needs to stay scalar, then // we need to gather the scalars. + // The reduction nodes (stored in UserIgnoreList) also should stay scalar. 
for (unsigned i = 0, e = VL.size(); i != e; ++i) { - if (MustGather.count(VL[i])) { + if (MustGather.count(VL[i]) || is_contained(UserIgnoreList, VL[i])) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); newTreeEntry(VL, false, UserTreeIdx); return; @@ -1548,7 +2177,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } } - newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n"); for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { @@ -1558,7 +2187,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Operands.push_back(cast<PHINode>(j)->getIncomingValueForBlock( PH->getIncomingBlock(i))); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, {TE, i}); } return; } @@ -1571,6 +2200,11 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, ++NumOpsWantToKeepOriginalOrder; newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, ReuseShuffleIndicies); + // This is a special case, as it does not gather, but at the same time + // we are not extending buildTree_rec() towards the operands. + ValueList Op0; + Op0.assign(VL.size(), VL0->getOperand(0)); + VectorizableTree.back()->setOperand(0, Op0, ReuseShuffleIndicies); return; } if (!CurrentOrder.empty()) { @@ -1588,6 +2222,11 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, ++StoredCurrentOrderAndNum->getSecond(); newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, ReuseShuffleIndicies, StoredCurrentOrderAndNum->getFirst()); + // This is a special case, as it does not gather, but at the same time + // we are not extending buildTree_rec() towards the operands. + ValueList Op0; + Op0.assign(VL.size(), VL0->getOperand(0)); + VectorizableTree.back()->setOperand(0, Op0, ReuseShuffleIndicies); return; } LLVM_DEBUG(dbgs() << "SLP: Gather extract sequence.\n"); @@ -1693,7 +2332,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return; } } - newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of casts.\n"); for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { @@ -1702,7 +2341,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, {TE, i}); } return; } @@ -1710,10 +2349,11 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::FCmp: { // Check that all of the compares have the same predicate. 
CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); + CmpInst::Predicate SwapP0 = CmpInst::getSwappedPredicate(P0); Type *ComparedTy = VL0->getOperand(0)->getType(); for (unsigned i = 1, e = VL.size(); i < e; ++i) { CmpInst *Cmp = cast<CmpInst>(VL[i]); - if (Cmp->getPredicate() != P0 || + if ((Cmp->getPredicate() != P0 && Cmp->getPredicate() != SwapP0) || Cmp->getOperand(0)->getType() != ComparedTy) { BS.cancelScheduling(VL, VL0); newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); @@ -1723,20 +2363,34 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } } - newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of compares.\n"); - for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { - ValueList Operands; - // Prepare the operand vector. - for (Value *j : VL) - Operands.push_back(cast<Instruction>(j)->getOperand(i)); - - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + ValueList Left, Right; + if (cast<CmpInst>(VL0)->isCommutative()) { + // Commutative predicate - collect + sort operands of the instructions + // so that each side is more likely to have the same opcode. + assert(P0 == SwapP0 && "Commutative Predicate mismatch"); + reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE); + } else { + // Collect operands - commute if it uses the swapped predicate. + for (Value *V : VL) { + auto *Cmp = cast<CmpInst>(V); + Value *LHS = Cmp->getOperand(0); + Value *RHS = Cmp->getOperand(1); + if (Cmp->getPredicate() != P0) + std::swap(LHS, RHS); + Left.push_back(LHS); + Right.push_back(RHS); + } } + + buildTree_rec(Left, Depth + 1, {TE, 0}); + buildTree_rec(Right, Depth + 1, {TE, 1}); return; } case Instruction::Select: + case Instruction::FNeg: case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -1754,17 +2408,17 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::AShr: case Instruction::And: case Instruction::Or: - case Instruction::Xor: - newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: added a vector of bin op.\n"); + case Instruction::Xor: { + auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: added a vector of un/bin op.\n"); // Sort operands of the instructions so that each side is more likely to // have the same opcode. if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) { ValueList Left, Right; - reorderInputsAccordingToOpcode(S.getOpcode(), VL, Left, Right); - buildTree_rec(Left, Depth + 1, UserTreeIdx); - buildTree_rec(Right, Depth + 1, UserTreeIdx); + reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE); + buildTree_rec(Left, Depth + 1, {TE, 0}); + buildTree_rec(Right, Depth + 1, {TE, 1}); return; } @@ -1774,10 +2428,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, {TE, i}); } return; - + } case Instruction::GetElementPtr: { // We don't combine GEPs with complicated (nested) indexing. 
for (unsigned j = 0; j < VL.size(); ++j) { @@ -1815,7 +2469,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } } - newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of GEPs.\n"); for (unsigned i = 0, e = 2; i < e; ++i) { ValueList Operands; @@ -1823,7 +2477,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, {TE, i}); } return; } @@ -1837,14 +2491,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return; } - newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n"); ValueList Operands; for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(0)); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, {TE, 0}); return; } case Instruction::Call: { @@ -1860,9 +2514,11 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return; } Function *Int = CI->getCalledFunction(); - Value *A1I = nullptr; - if (hasVectorInstrinsicScalarOpd(ID, 1)) - A1I = CI->getArgOperand(1); + unsigned NumArgs = CI->getNumArgOperands(); + SmallVector<Value*, 4> ScalarArgs(NumArgs, nullptr); + for (unsigned j = 0; j != NumArgs; ++j) + if (hasVectorInstrinsicScalarOpd(ID, j)) + ScalarArgs[j] = CI->getArgOperand(j); for (unsigned i = 1, e = VL.size(); i != e; ++i) { CallInst *CI2 = dyn_cast<CallInst>(VL[i]); if (!CI2 || CI2->getCalledFunction() != Int || @@ -1874,16 +2530,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, << "\n"); return; } - // ctlz,cttz and powi are special intrinsics whose second argument - // should be same in order for them to be vectorized. - if (hasVectorInstrinsicScalarOpd(ID, 1)) { - Value *A1J = CI2->getArgOperand(1); - if (A1I != A1J) { - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); - LLVM_DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI - << " argument " << A1I << "!=" << A1J << "\n"); - return; + // Some intrinsics have scalar arguments and should be same in order for + // them to be vectorized. + for (unsigned j = 0; j != NumArgs; ++j) { + if (hasVectorInstrinsicScalarOpd(ID, j)) { + Value *A1J = CI2->getArgOperand(j); + if (ScalarArgs[j] != A1J) { + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI + << " argument " << ScalarArgs[j] << "!=" << A1J + << "\n"); + return; + } } } // Verify that the bundle operands are identical between the two calls. @@ -1899,7 +2558,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } } - newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { ValueList Operands; // Prepare the operand vector. 
@@ -1907,11 +2566,11 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, CallInst *CI2 = dyn_cast<CallInst>(j); Operands.push_back(CI2->getArgOperand(i)); } - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, {TE, i}); } return; } - case Instruction::ShuffleVector: + case Instruction::ShuffleVector: { // If this is not an alternate sequence of opcode like add-sub // then do not vectorize this instruction. if (!S.isAltShuffle()) { @@ -1920,15 +2579,15 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n"); return; } - newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + auto *TE = newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n"); // Reorder operands if reordering would enable vectorization. if (isa<BinaryOperator>(VL0)) { ValueList Left, Right; - reorderAltShuffleOperands(S, VL, Left, Right); - buildTree_rec(Left, Depth + 1, UserTreeIdx); - buildTree_rec(Right, Depth + 1, UserTreeIdx); + reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE); + buildTree_rec(Left, Depth + 1, {TE, 0}); + buildTree_rec(Right, Depth + 1, {TE, 1}); return; } @@ -1938,10 +2597,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, {TE, i}); } return; - + } default: BS.cancelScheduling(VL, VL0); newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); @@ -2223,6 +2882,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { int VecCost = TTI->getCmpSelInstrCost(S.getOpcode(), VecTy, MaskTy, VL0); return ReuseShuffleCost + VecCost - ScalarCost; } + case Instruction::FNeg: case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -2260,7 +2920,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { ConstantInt *CInt0 = nullptr; for (unsigned i = 0, e = VL.size(); i < e; ++i) { const Instruction *I = cast<Instruction>(VL[i]); - ConstantInt *CInt = dyn_cast<ConstantInt>(I->getOperand(1)); + unsigned OpIdx = isa<BinaryOperator>(I) ? 1 : 0; + ConstantInt *CInt = dyn_cast<ConstantInt>(I->getOperand(OpIdx)); if (!CInt) { Op2VK = TargetTransformInfo::OK_AnyValue; Op2VP = TargetTransformInfo::OP_None; @@ -2413,31 +3074,31 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { } } -bool BoUpSLP::isFullyVectorizableTinyTree() { +bool BoUpSLP::isFullyVectorizableTinyTree() const { LLVM_DEBUG(dbgs() << "SLP: Check whether the tree with height " << VectorizableTree.size() << " is fully vectorizable .\n"); // We only handle trees of heights 1 and 2. - if (VectorizableTree.size() == 1 && !VectorizableTree[0].NeedToGather) + if (VectorizableTree.size() == 1 && !VectorizableTree[0]->NeedToGather) return true; if (VectorizableTree.size() != 2) return false; // Handle splat and all-constants stores. - if (!VectorizableTree[0].NeedToGather && - (allConstant(VectorizableTree[1].Scalars) || - isSplat(VectorizableTree[1].Scalars))) + if (!VectorizableTree[0]->NeedToGather && + (allConstant(VectorizableTree[1]->Scalars) || + isSplat(VectorizableTree[1]->Scalars))) return true; // Gathering cost would be too much for tiny trees. 
- if (VectorizableTree[0].NeedToGather || VectorizableTree[1].NeedToGather) + if (VectorizableTree[0]->NeedToGather || VectorizableTree[1]->NeedToGather) return false; return true; } -bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() { +bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const { // We can vectorize the tree if its size is greater than or equal to the // minimum size specified by the MinTreeSize command line option. if (VectorizableTree.size() >= MinTreeSize) @@ -2457,19 +3118,19 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() { return true; } -int BoUpSLP::getSpillCost() { +int BoUpSLP::getSpillCost() const { // Walk from the bottom of the tree to the top, tracking which values are // live. When we see a call instruction that is not part of our tree, // query TTI to see if there is a cost to keeping values live over it // (for example, if spills and fills are required). - unsigned BundleWidth = VectorizableTree.front().Scalars.size(); + unsigned BundleWidth = VectorizableTree.front()->Scalars.size(); int Cost = 0; SmallPtrSet<Instruction*, 4> LiveValues; Instruction *PrevInst = nullptr; - for (const auto &N : VectorizableTree) { - Instruction *Inst = dyn_cast<Instruction>(N.Scalars[0]); + for (const auto &TEPtr : VectorizableTree) { + Instruction *Inst = dyn_cast<Instruction>(TEPtr->Scalars[0]); if (!Inst) continue; @@ -2494,6 +3155,7 @@ int BoUpSLP::getSpillCost() { }); // Now find the sequence of instructions between PrevInst and Inst. + unsigned NumCalls = 0; BasicBlock::reverse_iterator InstIt = ++Inst->getIterator().getReverse(), PrevInstIt = PrevInst->getIterator().getReverse(); @@ -2506,16 +3168,19 @@ int BoUpSLP::getSpillCost() { // Debug informations don't impact spill cost. if ((isa<CallInst>(&*PrevInstIt) && !isa<DbgInfoIntrinsic>(&*PrevInstIt)) && - &*PrevInstIt != PrevInst) { - SmallVector<Type*, 4> V; - for (auto *II : LiveValues) - V.push_back(VectorType::get(II->getType(), BundleWidth)); - Cost += TTI->getCostOfKeepingLiveOverCall(V); - } + &*PrevInstIt != PrevInst) + NumCalls++; ++PrevInstIt; } + if (NumCalls) { + SmallVector<Type*, 4> V; + for (auto *II : LiveValues) + V.push_back(VectorType::get(II->getType(), BundleWidth)); + Cost += NumCalls * TTI->getCostOfKeepingLiveOverCall(V); + } + PrevInst = Inst; } @@ -2527,10 +3192,10 @@ int BoUpSLP::getTreeCost() { LLVM_DEBUG(dbgs() << "SLP: Calculating cost for tree of size " << VectorizableTree.size() << ".\n"); - unsigned BundleWidth = VectorizableTree[0].Scalars.size(); + unsigned BundleWidth = VectorizableTree[0]->Scalars.size(); for (unsigned I = 0, E = VectorizableTree.size(); I < E; ++I) { - TreeEntry &TE = VectorizableTree[I]; + TreeEntry &TE = *VectorizableTree[I].get(); // We create duplicate tree entries for gather sequences that have multiple // uses. However, we should not compute the cost of duplicate sequences. @@ -2545,10 +3210,11 @@ int BoUpSLP::getTreeCost() { // existing heuristics based on tree size may yield different results. // if (TE.NeedToGather && - std::any_of(std::next(VectorizableTree.begin(), I + 1), - VectorizableTree.end(), [TE](TreeEntry &Entry) { - return Entry.NeedToGather && Entry.isSame(TE.Scalars); - })) + std::any_of( + std::next(VectorizableTree.begin(), I + 1), VectorizableTree.end(), + [TE](const std::unique_ptr<TreeEntry> &EntryPtr) { + return EntryPtr->NeedToGather && EntryPtr->isSame(TE.Scalars); + })) continue; int C = getEntryCost(&TE); @@ -2575,7 +3241,7 @@ int BoUpSLP::getTreeCost() { // extend the extracted value back to the original type. 
Here, we account // for the extract and the added cost of the sign extend if needed. auto *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth); - auto *ScalarRoot = VectorizableTree[0].Scalars[0]; + auto *ScalarRoot = VectorizableTree[0]->Scalars[0]; if (MinBWs.count(ScalarRoot)) { auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first); auto Extend = @@ -2608,17 +3274,17 @@ int BoUpSLP::getTreeCost() { } int BoUpSLP::getGatherCost(Type *Ty, - const DenseSet<unsigned> &ShuffledIndices) { + const DenseSet<unsigned> &ShuffledIndices) const { int Cost = 0; for (unsigned i = 0, e = cast<VectorType>(Ty)->getNumElements(); i < e; ++i) if (!ShuffledIndices.count(i)) Cost += TTI->getVectorInstrCost(Instruction::InsertElement, Ty, i); if (!ShuffledIndices.empty()) - Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, Ty); + Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, Ty); return Cost; } -int BoUpSLP::getGatherCost(ArrayRef<Value *> VL) { +int BoUpSLP::getGatherCost(ArrayRef<Value *> VL) const { // Find the type of the operands in VL. Type *ScalarTy = VL[0]->getType(); if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) @@ -2638,221 +3304,19 @@ int BoUpSLP::getGatherCost(ArrayRef<Value *> VL) { return getGatherCost(VecTy, ShuffledElements); } -// Reorder commutative operations in alternate shuffle if the resulting vectors -// are consecutive loads. This would allow us to vectorize the tree. -// If we have something like- -// load a[0] - load b[0] -// load b[1] + load a[1] -// load a[2] - load b[2] -// load a[3] + load b[3] -// Reordering the second load b[1] load a[1] would allow us to vectorize this -// code. -void BoUpSLP::reorderAltShuffleOperands(const InstructionsState &S, - ArrayRef<Value *> VL, - SmallVectorImpl<Value *> &Left, - SmallVectorImpl<Value *> &Right) { - // Push left and right operands of binary operation into Left and Right - for (Value *V : VL) { - auto *I = cast<Instruction>(V); - assert(S.isOpcodeOrAlt(I) && "Incorrect instruction in vector"); - Left.push_back(I->getOperand(0)); - Right.push_back(I->getOperand(1)); - } - - // Reorder if we have a commutative operation and consecutive access - // are on either side of the alternate instructions. - for (unsigned j = 0; j < VL.size() - 1; ++j) { - if (LoadInst *L = dyn_cast<LoadInst>(Left[j])) { - if (LoadInst *L1 = dyn_cast<LoadInst>(Right[j + 1])) { - Instruction *VL1 = cast<Instruction>(VL[j]); - Instruction *VL2 = cast<Instruction>(VL[j + 1]); - if (VL1->isCommutative() && isConsecutiveAccess(L, L1, *DL, *SE)) { - std::swap(Left[j], Right[j]); - continue; - } else if (VL2->isCommutative() && - isConsecutiveAccess(L, L1, *DL, *SE)) { - std::swap(Left[j + 1], Right[j + 1]); - continue; - } - // else unchanged - } - } - if (LoadInst *L = dyn_cast<LoadInst>(Right[j])) { - if (LoadInst *L1 = dyn_cast<LoadInst>(Left[j + 1])) { - Instruction *VL1 = cast<Instruction>(VL[j]); - Instruction *VL2 = cast<Instruction>(VL[j + 1]); - if (VL1->isCommutative() && isConsecutiveAccess(L, L1, *DL, *SE)) { - std::swap(Left[j], Right[j]); - continue; - } else if (VL2->isCommutative() && - isConsecutiveAccess(L, L1, *DL, *SE)) { - std::swap(Left[j + 1], Right[j + 1]); - continue; - } - // else unchanged - } - } - } -} - -// Return true if I should be commuted before adding it's left and right -// operands to the arrays Left and Right. 
-// -// The vectorizer is trying to either have all elements one side being -// instruction with the same opcode to enable further vectorization, or having -// a splat to lower the vectorizing cost. -static bool shouldReorderOperands( - int i, unsigned Opcode, Instruction &I, ArrayRef<Value *> Left, - ArrayRef<Value *> Right, bool AllSameOpcodeLeft, bool AllSameOpcodeRight, - bool SplatLeft, bool SplatRight, Value *&VLeft, Value *&VRight) { - VLeft = I.getOperand(0); - VRight = I.getOperand(1); - // If we have "SplatRight", try to see if commuting is needed to preserve it. - if (SplatRight) { - if (VRight == Right[i - 1]) - // Preserve SplatRight - return false; - if (VLeft == Right[i - 1]) { - // Commuting would preserve SplatRight, but we don't want to break - // SplatLeft either, i.e. preserve the original order if possible. - // (FIXME: why do we care?) - if (SplatLeft && VLeft == Left[i - 1]) - return false; - return true; - } - } - // Symmetrically handle Right side. - if (SplatLeft) { - if (VLeft == Left[i - 1]) - // Preserve SplatLeft - return false; - if (VRight == Left[i - 1]) - return true; - } - - Instruction *ILeft = dyn_cast<Instruction>(VLeft); - Instruction *IRight = dyn_cast<Instruction>(VRight); - - // If we have "AllSameOpcodeRight", try to see if the left operands preserves - // it and not the right, in this case we want to commute. - if (AllSameOpcodeRight) { - unsigned RightPrevOpcode = cast<Instruction>(Right[i - 1])->getOpcode(); - if (IRight && RightPrevOpcode == IRight->getOpcode()) - // Do not commute, a match on the right preserves AllSameOpcodeRight - return false; - if (ILeft && RightPrevOpcode == ILeft->getOpcode()) { - // We have a match and may want to commute, but first check if there is - // not also a match on the existing operands on the Left to preserve - // AllSameOpcodeLeft, i.e. preserve the original order if possible. - // (FIXME: why do we care?) - if (AllSameOpcodeLeft && ILeft && - cast<Instruction>(Left[i - 1])->getOpcode() == ILeft->getOpcode()) - return false; - return true; - } - } - // Symmetrically handle Left side. - if (AllSameOpcodeLeft) { - unsigned LeftPrevOpcode = cast<Instruction>(Left[i - 1])->getOpcode(); - if (ILeft && LeftPrevOpcode == ILeft->getOpcode()) - return false; - if (IRight && LeftPrevOpcode == IRight->getOpcode()) - return true; - } - return false; -} - -void BoUpSLP::reorderInputsAccordingToOpcode(unsigned Opcode, - ArrayRef<Value *> VL, - SmallVectorImpl<Value *> &Left, - SmallVectorImpl<Value *> &Right) { - if (!VL.empty()) { - // Peel the first iteration out of the loop since there's nothing - // interesting to do anyway and it simplifies the checks in the loop. - auto *I = cast<Instruction>(VL[0]); - Value *VLeft = I->getOperand(0); - Value *VRight = I->getOperand(1); - if (!isa<Instruction>(VRight) && isa<Instruction>(VLeft)) - // Favor having instruction to the right. FIXME: why? - std::swap(VLeft, VRight); - Left.push_back(VLeft); - Right.push_back(VRight); - } - - // Keep track if we have instructions with all the same opcode on one side. - bool AllSameOpcodeLeft = isa<Instruction>(Left[0]); - bool AllSameOpcodeRight = isa<Instruction>(Right[0]); - // Keep track if we have one side with all the same value (broadcast). 
- bool SplatLeft = true; - bool SplatRight = true; - - for (unsigned i = 1, e = VL.size(); i != e; ++i) { - Instruction *I = cast<Instruction>(VL[i]); - assert(((I->getOpcode() == Opcode && I->isCommutative()) || - (I->getOpcode() != Opcode && Instruction::isCommutative(Opcode))) && - "Can only process commutative instruction"); - // Commute to favor either a splat or maximizing having the same opcodes on - // one side. - Value *VLeft; - Value *VRight; - if (shouldReorderOperands(i, Opcode, *I, Left, Right, AllSameOpcodeLeft, - AllSameOpcodeRight, SplatLeft, SplatRight, VLeft, - VRight)) { - Left.push_back(VRight); - Right.push_back(VLeft); - } else { - Left.push_back(VLeft); - Right.push_back(VRight); - } - // Update Splat* and AllSameOpcode* after the insertion. - SplatRight = SplatRight && (Right[i - 1] == Right[i]); - SplatLeft = SplatLeft && (Left[i - 1] == Left[i]); - AllSameOpcodeLeft = AllSameOpcodeLeft && isa<Instruction>(Left[i]) && - (cast<Instruction>(Left[i - 1])->getOpcode() == - cast<Instruction>(Left[i])->getOpcode()); - AllSameOpcodeRight = AllSameOpcodeRight && isa<Instruction>(Right[i]) && - (cast<Instruction>(Right[i - 1])->getOpcode() == - cast<Instruction>(Right[i])->getOpcode()); - } - - // If one operand end up being broadcast, return this operand order. - if (SplatRight || SplatLeft) +// Perform operand reordering on the instructions in VL and return the reordered +// operands in Left and Right. +void BoUpSLP::reorderInputsAccordingToOpcode( + ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, + SmallVectorImpl<Value *> &Right, const DataLayout &DL, + ScalarEvolution &SE) { + if (VL.empty()) return; - - // Finally check if we can get longer vectorizable chain by reordering - // without breaking the good operand order detected above. - // E.g. If we have something like- - // load a[0] load b[0] - // load b[1] load a[1] - // load a[2] load b[2] - // load a[3] load b[3] - // Reordering the second load b[1] load a[1] would allow us to vectorize - // this code and we still retain AllSameOpcode property. - // FIXME: This load reordering might break AllSameOpcode in some rare cases - // such as- - // add a[0],c[0] load b[0] - // add a[1],c[2] load b[1] - // b[2] load b[2] - // add a[3],c[3] load b[3] - for (unsigned j = 0, e = VL.size() - 1; j < e; ++j) { - if (LoadInst *L = dyn_cast<LoadInst>(Left[j])) { - if (LoadInst *L1 = dyn_cast<LoadInst>(Right[j + 1])) { - if (isConsecutiveAccess(L, L1, *DL, *SE)) { - std::swap(Left[j + 1], Right[j + 1]); - continue; - } - } - } - if (LoadInst *L = dyn_cast<LoadInst>(Right[j])) { - if (LoadInst *L1 = dyn_cast<LoadInst>(Left[j + 1])) { - if (isConsecutiveAccess(L, L1, *DL, *SE)) { - std::swap(Left[j + 1], Right[j + 1]); - continue; - } - } - } - // else unchanged - } + VLOperands Ops(VL, DL, SE); + // Reorder the operands in place. + Ops.reorder(); + Left = Ops.getVL(0); + Right = Ops.getVL(1); } void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL, @@ -3082,13 +3546,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { continue; } - // Prepare the operand vector. 
- for (Value *V : E->Scalars) - Operands.push_back(cast<PHINode>(V)->getIncomingValueForBlock(IBB)); - Builder.SetInsertPoint(IBB->getTerminator()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); - Value *Vec = vectorizeTree(Operands); + Value *Vec = vectorizeTree(E->getOperand(i)); NewPhi->addIncoming(Vec, IBB); } @@ -3099,7 +3559,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::ExtractElement: { if (!E->NeedToGather) { - Value *V = VL0->getOperand(0); + Value *V = E->getSingleOperand(0); if (!E->ReorderIndices.empty()) { OrdersType Mask; inversePermutation(E->ReorderIndices, Mask); @@ -3132,11 +3592,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } case Instruction::ExtractValue: { if (!E->NeedToGather) { - LoadInst *LI = cast<LoadInst>(VL0->getOperand(0)); + LoadInst *LI = cast<LoadInst>(E->getSingleOperand(0)); Builder.SetInsertPoint(LI); PointerType *PtrTy = PointerType::get(VecTy, LI->getPointerAddressSpace()); Value *Ptr = Builder.CreateBitCast(LI->getOperand(0), PtrTy); - LoadInst *V = Builder.CreateAlignedLoad(Ptr, LI->getAlignment()); + LoadInst *V = Builder.CreateAlignedLoad(VecTy, Ptr, LI->getAlignment()); Value *NewV = propagateMetadata(V, E->Scalars); if (!E->ReorderIndices.empty()) { OrdersType Mask; @@ -3177,13 +3637,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { - ValueList INVL; - for (Value *V : E->Scalars) - INVL.push_back(cast<Instruction>(V)->getOperand(0)); - setInsertPointAfterBundle(E->Scalars, S); - Value *InVec = vectorizeTree(INVL); + Value *InVec = vectorizeTree(E->getOperand(0)); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); @@ -3202,16 +3658,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } case Instruction::FCmp: case Instruction::ICmp: { - ValueList LHSV, RHSV; - for (Value *V : E->Scalars) { - LHSV.push_back(cast<Instruction>(V)->getOperand(0)); - RHSV.push_back(cast<Instruction>(V)->getOperand(1)); - } - setInsertPointAfterBundle(E->Scalars, S); - Value *L = vectorizeTree(LHSV); - Value *R = vectorizeTree(RHSV); + Value *L = vectorizeTree(E->getOperand(0)); + Value *R = vectorizeTree(E->getOperand(1)); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); @@ -3235,31 +3685,49 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return V; } case Instruction::Select: { - ValueList TrueVec, FalseVec, CondVec; - for (Value *V : E->Scalars) { - CondVec.push_back(cast<Instruction>(V)->getOperand(0)); - TrueVec.push_back(cast<Instruction>(V)->getOperand(1)); - FalseVec.push_back(cast<Instruction>(V)->getOperand(2)); + setInsertPointAfterBundle(E->Scalars, S); + + Value *Cond = vectorizeTree(E->getOperand(0)); + Value *True = vectorizeTree(E->getOperand(1)); + Value *False = vectorizeTree(E->getOperand(2)); + + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; } + Value *V = Builder.CreateSelect(Cond, True, False); + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } + E->VectorizedValue = V; + ++NumVectorInstructions; + return V; + } + case Instruction::FNeg: { setInsertPointAfterBundle(E->Scalars, S); - Value *Cond = vectorizeTree(CondVec); - Value *True = vectorizeTree(TrueVec); - Value *False = vectorizeTree(FalseVec); + Value *Op = vectorizeTree(E->getOperand(0)); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << 
"SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - Value *V = Builder.CreateSelect(Cond, True, False); + Value *V = Builder.CreateUnOp( + static_cast<Instruction::UnaryOps>(S.getOpcode()), Op); + propagateIRFlags(V, E->Scalars, VL0); + if (auto *I = dyn_cast<Instruction>(V)) + V = propagateMetadata(I, E->Scalars); + if (NeedToShuffleReuses) { V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), E->ReuseShuffleIndices, "shuffle"); } E->VectorizedValue = V; ++NumVectorInstructions; + return V; } case Instruction::Add: @@ -3280,21 +3748,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::And: case Instruction::Or: case Instruction::Xor: { - ValueList LHSVL, RHSVL; - if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) - reorderInputsAccordingToOpcode(S.getOpcode(), E->Scalars, LHSVL, - RHSVL); - else - for (Value *V : E->Scalars) { - auto *I = cast<Instruction>(V); - LHSVL.push_back(I->getOperand(0)); - RHSVL.push_back(I->getOperand(1)); - } - setInsertPointAfterBundle(E->Scalars, S); - Value *LHS = vectorizeTree(LHSVL); - Value *RHS = vectorizeTree(RHSVL); + Value *LHS = vectorizeTree(E->getOperand(0)); + Value *RHS = vectorizeTree(E->getOperand(1)); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); @@ -3341,7 +3798,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { ExternalUses.push_back(ExternalUser(PO, cast<User>(VecPtr), 0)); unsigned Alignment = LI->getAlignment(); - LI = Builder.CreateLoad(VecPtr); + LI = Builder.CreateLoad(VecTy, VecPtr); if (!Alignment) { Alignment = DL->getABITypeAlignment(ScalarLoadTy); } @@ -3367,13 +3824,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { unsigned Alignment = SI->getAlignment(); unsigned AS = SI->getPointerAddressSpace(); - ValueList ScalarStoreValues; - for (Value *V : E->Scalars) - ScalarStoreValues.push_back(cast<StoreInst>(V)->getValueOperand()); - setInsertPointAfterBundle(E->Scalars, S); - Value *VecValue = vectorizeTree(ScalarStoreValues); + Value *VecValue = vectorizeTree(E->getOperand(0)); Value *ScalarPtr = SI->getPointerOperand(); Value *VecPtr = Builder.CreateBitCast(ScalarPtr, VecTy->getPointerTo(AS)); StoreInst *ST = Builder.CreateStore(VecValue, VecPtr); @@ -3400,20 +3853,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::GetElementPtr: { setInsertPointAfterBundle(E->Scalars, S); - ValueList Op0VL; - for (Value *V : E->Scalars) - Op0VL.push_back(cast<GetElementPtrInst>(V)->getOperand(0)); - - Value *Op0 = vectorizeTree(Op0VL); + Value *Op0 = vectorizeTree(E->getOperand(0)); std::vector<Value *> OpVecs; for (int j = 1, e = cast<GetElementPtrInst>(VL0)->getNumOperands(); j < e; ++j) { - ValueList OpVL; - for (Value *V : E->Scalars) - OpVL.push_back(cast<GetElementPtrInst>(V)->getOperand(j)); - - Value *OpVec = vectorizeTree(OpVL); + Value *OpVec = vectorizeTree(E->getOperand(j)); OpVecs.push_back(OpVec); } @@ -3443,20 +3888,16 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { std::vector<Value *> OpVecs; for (int j = 0, e = CI->getNumArgOperands(); j < e; ++j) { ValueList OpVL; - // ctlz,cttz and powi are special intrinsics whose second argument is - // a scalar. This argument should not be vectorized. - if (hasVectorInstrinsicScalarOpd(IID, 1) && j == 1) { + // Some intrinsics have scalar arguments. This argument should not be + // vectorized. 
+ if (hasVectorInstrinsicScalarOpd(IID, j)) { CallInst *CEI = cast<CallInst>(VL0); ScalarArg = CEI->getArgOperand(j); OpVecs.push_back(CEI->getArgOperand(j)); continue; } - for (Value *V : E->Scalars) { - CallInst *CEI = cast<CallInst>(V); - OpVL.push_back(CEI->getArgOperand(j)); - } - Value *OpVec = vectorizeTree(OpVL); + Value *OpVec = vectorizeTree(E->getOperand(j)); LLVM_DEBUG(dbgs() << "SLP: OpVec[" << j << "]: " << *OpVec << "\n"); OpVecs.push_back(OpVec); } @@ -3485,7 +3926,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return V; } case Instruction::ShuffleVector: { - ValueList LHSVL, RHSVL; assert(S.isAltShuffle() && ((Instruction::isBinaryOp(S.getOpcode()) && Instruction::isBinaryOp(S.getAltOpcode())) || @@ -3495,16 +3935,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *LHS, *RHS; if (Instruction::isBinaryOp(S.getOpcode())) { - reorderAltShuffleOperands(S, E->Scalars, LHSVL, RHSVL); setInsertPointAfterBundle(E->Scalars, S); - LHS = vectorizeTree(LHSVL); - RHS = vectorizeTree(RHSVL); + LHS = vectorizeTree(E->getOperand(0)); + RHS = vectorizeTree(E->getOperand(1)); } else { - ValueList INVL; - for (Value *V : E->Scalars) - INVL.push_back(cast<Instruction>(V)->getOperand(0)); setInsertPointAfterBundle(E->Scalars, S); - LHS = vectorizeTree(INVL); + LHS = vectorizeTree(E->getOperand(0)); } if (E->VectorizedValue) { @@ -3578,20 +4014,20 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { } Builder.SetInsertPoint(&F->getEntryBlock().front()); - auto *VectorRoot = vectorizeTree(&VectorizableTree[0]); + auto *VectorRoot = vectorizeTree(VectorizableTree[0].get()); // If the vectorized tree can be rewritten in a smaller type, we truncate the // vectorized root. InstCombine will then rewrite the entire expression. We // sign extend the extracted values below. - auto *ScalarRoot = VectorizableTree[0].Scalars[0]; + auto *ScalarRoot = VectorizableTree[0]->Scalars[0]; if (MinBWs.count(ScalarRoot)) { if (auto *I = dyn_cast<Instruction>(VectorRoot)) Builder.SetInsertPoint(&*++BasicBlock::iterator(I)); - auto BundleWidth = VectorizableTree[0].Scalars.size(); + auto BundleWidth = VectorizableTree[0]->Scalars.size(); auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first); auto *VecTy = VectorType::get(MinTy, BundleWidth); auto *Trunc = Builder.CreateTrunc(VectorRoot, VecTy); - VectorizableTree[0].VectorizedValue = Trunc; + VectorizableTree[0]->VectorizedValue = Trunc; } LLVM_DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() @@ -3687,8 +4123,8 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { } // For each vectorized value: - for (TreeEntry &EIdx : VectorizableTree) { - TreeEntry *Entry = &EIdx; + for (auto &TEPtr : VectorizableTree) { + TreeEntry *Entry = TEPtr.get(); // No need to handle users of gathered values. if (Entry->NeedToGather) @@ -3721,7 +4157,7 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { Builder.ClearInsertionPoint(); - return VectorizableTree[0].VectorizedValue; + return VectorizableTree[0]->VectorizedValue; } void BoUpSLP::optimizeGatherSequence() { @@ -3767,10 +4203,10 @@ void BoUpSLP::optimizeGatherSequence() { // Sort blocks by domination. This ensures we visit a block after all blocks // dominating it are visited. 
- std::stable_sort(CSEWorkList.begin(), CSEWorkList.end(), - [this](const DomTreeNode *A, const DomTreeNode *B) { - return DT->properlyDominates(A, B); - }); + llvm::stable_sort(CSEWorkList, + [this](const DomTreeNode *A, const DomTreeNode *B) { + return DT->properlyDominates(A, B); + }); // Perform O(N^2) search over the gather sequences and merge identical // instructions. TODO: We can further optimize this scan if we split the @@ -3989,7 +4425,7 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, << "\n"); return true; } - UpIter++; + ++UpIter; } if (DownIter != LowerEnd) { if (&*DownIter == I) { @@ -4003,7 +4439,7 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, << "\n"); return true; } - DownIter++; + ++DownIter; } assert((UpIter != UpperEnd || DownIter != LowerEnd) && "instruction not found in block"); @@ -4253,7 +4689,7 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) { BS->ScheduleStart = nullptr; } -unsigned BoUpSLP::getVectorElementSize(Value *V) { +unsigned BoUpSLP::getVectorElementSize(Value *V) const { // If V is a store, just return the width of the stored value without // traversing the expression tree. This is the common case. if (auto *Store = dyn_cast<StoreInst>(V)) @@ -4390,7 +4826,7 @@ void BoUpSLP::computeMinimumValueSizes() { return; // We only attempt to truncate integer expressions. - auto &TreeRoot = VectorizableTree[0].Scalars; + auto &TreeRoot = VectorizableTree[0]->Scalars; auto *TreeRootIT = dyn_cast<IntegerType>(TreeRoot[0]->getType()); if (!TreeRootIT) return; @@ -4411,8 +4847,8 @@ void BoUpSLP::computeMinimumValueSizes() { // Collect the scalar values of the vectorizable expression. We will use this // context to determine which values can be demoted. If we see a truncation, // we mark it as seeding another demotion. - for (auto &Entry : VectorizableTree) - Expr.insert(Entry.Scalars.begin(), Entry.Scalars.end()); + for (auto &EntryPtr : VectorizableTree) + Expr.insert(EntryPtr->Scalars.begin(), EntryPtr->Scalars.end()); // Ensure the roots of the vectorizable tree don't form a cycle. They must // have a single external user that is not in the vectorizable tree. @@ -4746,38 +5182,29 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, BoUpSLP::ValueSet VectorizedStores; bool Changed = false; - // Do a quadratic search on all of the given stores in reverse order and find - // all of the pairs of stores that follow each other. - SmallVector<unsigned, 16> IndexQueue; - unsigned E = Stores.size(); - IndexQueue.resize(E - 1); - for (unsigned I = E; I > 0; --I) { - unsigned Idx = I - 1; - // If a store has multiple consecutive store candidates, search Stores - // array according to the sequence: Idx-1, Idx+1, Idx-2, Idx+2, ... - // This is because usually pairing with immediate succeeding or preceding - // candidate create the best chance to find slp vectorization opportunity. 
- unsigned Offset = 1; - unsigned Cnt = 0; - for (unsigned J = 0; J < E - 1; ++J, ++Offset) { - if (Idx >= Offset) { - IndexQueue[Cnt] = Idx - Offset; - ++Cnt; - } - if (Idx + Offset < E) { - IndexQueue[Cnt] = Idx + Offset; - ++Cnt; - } - } + auto &&FindConsecutiveAccess = + [this, &Stores, &Heads, &Tails, &ConsecutiveChain] (int K, int Idx) { + if (!isConsecutiveAccess(Stores[K], Stores[Idx], *DL, *SE)) + return false; - for (auto K : IndexQueue) { - if (isConsecutiveAccess(Stores[K], Stores[Idx], *DL, *SE)) { Tails.insert(Stores[Idx]); Heads.insert(Stores[K]); ConsecutiveChain[Stores[K]] = Stores[Idx]; + return true; + }; + + // Do a quadratic search on all of the given stores in reverse order and find + // all of the pairs of stores that follow each other. + int E = Stores.size(); + for (int Idx = E - 1; Idx >= 0; --Idx) { + // If a store has multiple consecutive store candidates, search according + // to the sequence: Idx-1, Idx+1, Idx-2, Idx+2, ... + // This is because pairing with the immediately succeeding or preceding + // candidate usually creates the best chance to find an SLP vectorization opportunity. + for (int Offset = 1, F = std::max(E - Idx, Idx + 1); Offset < F; ++Offset) + if ((Idx >= Offset && FindConsecutiveAccess(Idx - Offset, Idx)) || + (Idx + Offset < E && FindConsecutiveAccess(Idx + Offset, Idx))) break; - } - } } // For stores that start but don't end a link in the chain: @@ -5740,6 +6167,9 @@ public: unsigned ReduxWidth = PowerOf2Floor(NumReducedVals); Value *VectorizedTree = nullptr; + + // FIXME: Fast-math-flags should be set based on the instructions in the + // reduction (not all of 'fast' are required). IRBuilder<> Builder(cast<Instruction>(ReductionRoot)); FastMathFlags Unsafe; Unsafe.setFast(); @@ -5929,10 +6359,14 @@ private: assert(isPowerOf2_32(ReduxWidth) && "We only handle power-of-two reductions for now"); - if (!IsPairwiseReduction) + if (!IsPairwiseReduction) { + // FIXME: The builder should use an FMF guard. It should not be hard-coded + // to 'fast'. + assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF"); return createSimpleTargetReduction( Builder, TTI, ReductionData.getOpcode(), VectorizedValue, ReductionData.getFlags(), ReductionOps.back()); + } Value *TmpVec = VectorizedValue; for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) { @@ -6256,7 +6690,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { } // Sort by type. - std::stable_sort(Incoming.begin(), Incoming.end(), PhiTypeSorterFunc); + llvm::stable_sort(Incoming, PhiTypeSorterFunc); // Try to vectorize elements based on their type. for (SmallVector<Value *, 4>::iterator IncIt = Incoming.begin(), @@ -6297,7 +6731,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { SmallVector<WeakVH, 8> PostProcessInstructions; SmallDenseSet<Instruction *, 4> KeyNodes; - for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; it++) { + for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { // We may go through BB multiple times so skip the one we have checked.
if (!VisitedInstrs.insert(&*it).second) { if (it->use_empty() && KeyNodes.count(&*it) > 0 && diff --git a/lib/Transforms/Vectorize/VPRecipeBuilder.h b/lib/Transforms/Vectorize/VPRecipeBuilder.h index 15d38ac9c84c..0ca6a6b93cfd 100644 --- a/lib/Transforms/Vectorize/VPRecipeBuilder.h +++ b/lib/Transforms/Vectorize/VPRecipeBuilder.h @@ -1,9 +1,8 @@ //===- VPRecipeBuilder.h - Helper class to build recipes --------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// @@ -30,9 +29,6 @@ class VPRecipeBuilder { /// Target Library Info. const TargetLibraryInfo *TLI; - /// Target Transform Info. - const TargetTransformInfo *TTI; - /// The legality analysis. LoopVectorizationLegality *Legal; @@ -105,11 +101,9 @@ public: public: VPRecipeBuilder(Loop *OrigLoop, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, LoopVectorizationLegality *Legal, LoopVectorizationCostModel &CM, VPBuilder &Builder) - : OrigLoop(OrigLoop), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM), - Builder(Builder) {} + : OrigLoop(OrigLoop), TLI(TLI), Legal(Legal), CM(CM), Builder(Builder) {} /// Check if a recipe can be created for \p I within the given VF \p Range. /// If a recipe can be created, it adds it to \p VPBB. diff --git a/lib/Transforms/Vectorize/VPlan.cpp b/lib/Transforms/Vectorize/VPlan.cpp index 05a5400beb4e..517d759d7bfc 100644 --- a/lib/Transforms/Vectorize/VPlan.cpp +++ b/lib/Transforms/Vectorize/VPlan.cpp @@ -1,9 +1,8 @@ //===- VPlan.cpp - Vectorizer Plan ----------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// @@ -374,10 +373,9 @@ void VPlan::execute(VPTransformState *State) { BasicBlock *VectorPreHeaderBB = State->CFG.PrevBB; BasicBlock *VectorHeaderBB = VectorPreHeaderBB->getSingleSuccessor(); assert(VectorHeaderBB && "Loop preheader does not have a single successor."); - BasicBlock *VectorLatchBB = VectorHeaderBB; // 1. Make room to generate basic-blocks inside loop body if needed. - VectorLatchBB = VectorHeaderBB->splitBasicBlock( + BasicBlock *VectorLatchBB = VectorHeaderBB->splitBasicBlock( VectorHeaderBB->getFirstInsertionPt(), "vector.body.latch"); Loop *L = State->LI->getLoopFor(VectorHeaderBB); L->addBasicBlockToLoop(VectorLatchBB, *State->LI); @@ -561,6 +559,19 @@ void VPlanPrinter::dumpBasicBlock(const VPBasicBlock *BasicBlock) { bumpIndent(1); OS << Indent << "\"" << DOT::EscapeString(BasicBlock->getName()) << ":\\n\""; bumpIndent(1); + + // Dump the block predicate.
+ const VPValue *Pred = BasicBlock->getPredicate(); + if (Pred) { + OS << " +\n" << Indent << " \"BlockPredicate: "; + if (const VPInstruction *PredI = dyn_cast<VPInstruction>(Pred)) { + PredI->printAsOperand(OS); + OS << " (" << DOT::EscapeString(PredI->getParent()->getName()) + << ")\\l\""; + } else + Pred->printAsOperand(OS); + } + for (const VPRecipeBase &Recipe : *BasicBlock) Recipe.print(OS, Indent); diff --git a/lib/Transforms/Vectorize/VPlan.h b/lib/Transforms/Vectorize/VPlan.h index 5c1b4a83c30e..8a06412ad590 100644 --- a/lib/Transforms/Vectorize/VPlan.h +++ b/lib/Transforms/Vectorize/VPlan.h @@ -1,9 +1,8 @@ //===- VPlan.h - Represent A Vectorizer Plan --------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -353,6 +352,9 @@ private: /// Successor selector, null for zero or single successor blocks. VPValue *CondBit = nullptr; + /// Current block predicate - null if the block does not need a predicate. + VPValue *Predicate = nullptr; + /// Add \p Successor as the last successor to this block. void appendSuccessor(VPBlockBase *Successor) { assert(Successor && "Cannot add nullptr successor!"); @@ -491,6 +493,12 @@ public: void setCondBit(VPValue *CV) { CondBit = CV; } + VPValue *getPredicate() { return Predicate; } + + const VPValue *getPredicate() const { return Predicate; } + + void setPredicate(VPValue *Pred) { Predicate = Pred; } + /// Set a given VPBlockBase \p Successor as the single successor of this /// VPBlockBase. This VPBlockBase is not added as predecessor of \p Successor. /// This VPBlockBase must have no successors. @@ -521,6 +529,15 @@ public: appendPredecessor(Pred); } + /// Remove all the predecessors of this block. + void clearPredecessors() { Predecessors.clear(); } + + /// Remove all the successors of this block and set its condition bit to null. + void clearSuccessors() { + Successors.clear(); + CondBit = nullptr; + } + /// The method which generates the output IR that correspond to this /// VPBlockBase, thereby "executing" the VPlan. virtual void execute(struct VPTransformState *State) = 0; @@ -1491,6 +1508,41 @@ public: From->removeSuccessor(To); To->removePredecessor(From); } + + /// Returns true if the edge \p FromBlock -> \p ToBlock is a back-edge. + static bool isBackEdge(const VPBlockBase *FromBlock, + const VPBlockBase *ToBlock, const VPLoopInfo *VPLI) { + assert(FromBlock->getParent() == ToBlock->getParent() && + FromBlock->getParent() && "Must be in same region"); + const VPLoop *FromLoop = VPLI->getLoopFor(FromBlock); + const VPLoop *ToLoop = VPLI->getLoopFor(ToBlock); + if (!FromLoop || !ToLoop || FromLoop != ToLoop) + return false; + + // A back-edge is a branch from the loop latch to its header. + return ToLoop->isLoopLatch(FromBlock) && ToBlock == ToLoop->getHeader(); + } + + /// Returns true if \p Block is a loop latch. + static bool blockIsLoopLatch(const VPBlockBase *Block, + const VPLoopInfo *VPLInfo) { + if (const VPLoop *ParentVPL = VPLInfo->getLoopFor(Block)) + return ParentVPL->isLoopLatch(Block); + + return false; + } + + /// Count and return the number of successors of \p PredBlock excluding any + /// backedges.
+ static unsigned countSuccessorsNoBE(VPBlockBase *PredBlock, + VPLoopInfo *VPLI) { + unsigned Count = 0; + for (VPBlockBase *SuccBlock : PredBlock->getSuccessors()) { + if (!VPBlockUtils::isBackEdge(PredBlock, SuccBlock, VPLI)) + Count++; + } + return Count; + } }; class VPInterleavedAccessInfo { diff --git a/lib/Transforms/Vectorize/VPlanDominatorTree.h b/lib/Transforms/Vectorize/VPlanDominatorTree.h index 1b81097b6d31..19f5d2c00c60 100644 --- a/lib/Transforms/Vectorize/VPlanDominatorTree.h +++ b/lib/Transforms/Vectorize/VPlanDominatorTree.h @@ -1,9 +1,8 @@ //===-- VPlanDominatorTree.h ------------------------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// diff --git a/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp index 0f42694e193b..df96f67288f1 100644 --- a/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp +++ b/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp @@ -1,9 +1,8 @@ //===-- VPlanHCFGBuilder.cpp ----------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// @@ -64,7 +63,9 @@ private: void setVPBBPredsFromBB(VPBasicBlock *VPBB, BasicBlock *BB); void fixPhiNodes(); VPBasicBlock *getOrCreateVPBB(BasicBlock *BB); +#ifndef NDEBUG bool isExternalDef(Value *Val); +#endif VPValue *getOrCreateVPOperand(Value *IRVal); void createVPInstructionsForVPBB(VPBasicBlock *VPBB, BasicBlock *BB); @@ -119,6 +120,7 @@ VPBasicBlock *PlainCFGBuilder::getOrCreateVPBB(BasicBlock *BB) { return VPBB; } +#ifndef NDEBUG // Return true if \p Val is considered an external definition. An external // definition is either: // 1. A Value that is not an Instruction. This will be refined in the future. @@ -154,6 +156,7 @@ bool PlainCFGBuilder::isExternalDef(Value *Val) { // Check whether Instruction definition is in loop body. return !TheLoop->contains(Inst); } +#endif // Create a new VPValue or retrieve an existing one for the Instruction's // operand \p IRVal. This function must only be used to create/retrieve VPValues diff --git a/lib/Transforms/Vectorize/VPlanHCFGBuilder.h b/lib/Transforms/Vectorize/VPlanHCFGBuilder.h index 3f11dcb5164d..238ee7e6347c 100644 --- a/lib/Transforms/Vectorize/VPlanHCFGBuilder.h +++ b/lib/Transforms/Vectorize/VPlanHCFGBuilder.h @@ -1,9 +1,8 @@ //===-- VPlanHCFGBuilder.h --------------------------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// diff --git a/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp b/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp index 3ad7fc7e7b96..7ed7d21b6caa 100644 --- a/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp +++ b/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp @@ -1,9 +1,8 @@ //===-- VPlanHCFGTransforms.cpp - Utility VPlan to VPlan transforms -------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// diff --git a/lib/Transforms/Vectorize/VPlanHCFGTransforms.h b/lib/Transforms/Vectorize/VPlanHCFGTransforms.h index ae549c6871b3..79a23c33184f 100644 --- a/lib/Transforms/Vectorize/VPlanHCFGTransforms.h +++ b/lib/Transforms/Vectorize/VPlanHCFGTransforms.h @@ -1,9 +1,8 @@ //===- VPlanHCFGTransforms.h - Utility VPlan to VPlan transforms ----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// diff --git a/lib/Transforms/Vectorize/VPlanLoopInfo.h b/lib/Transforms/Vectorize/VPlanLoopInfo.h index 5c2485fc2145..5208f2d58e2b 100644 --- a/lib/Transforms/Vectorize/VPlanLoopInfo.h +++ b/lib/Transforms/Vectorize/VPlanLoopInfo.h @@ -1,9 +1,8 @@ //===-- VPLoopInfo.h --------------------------------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// diff --git a/lib/Transforms/Vectorize/VPlanPredicator.cpp b/lib/Transforms/Vectorize/VPlanPredicator.cpp new file mode 100644 index 000000000000..7a80f3ff80a5 --- /dev/null +++ b/lib/Transforms/Vectorize/VPlanPredicator.cpp @@ -0,0 +1,248 @@ +//===-- VPlanPredicator.cpp -------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the VPlanPredicator class which contains the public +/// interfaces to predicate and linearize the VPlan region. 
+/// +//===----------------------------------------------------------------------===// + +#include "VPlanPredicator.h" +#include "VPlan.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "VPlanPredicator" + +using namespace llvm; + +// Generate VPInstructions at the beginning of CurrBB that calculate the +// predicate being propagated from PredBB to CurrBB depending on the edge type +// between them. For example if: +// i. PredBB is controlled by predicate %BP, and +// ii. The edge PredBB->CurrBB is the false edge, controlled by the condition +// bit value %CBV then this function will generate the following two +// VPInstructions at the start of CurrBB: +// %IntermediateVal = not %CBV +// %FinalVal = and %BP %IntermediateVal +// It returns %FinalVal. +VPValue *VPlanPredicator::getOrCreateNotPredicate(VPBasicBlock *PredBB, + VPBasicBlock *CurrBB) { + VPValue *CBV = PredBB->getCondBit(); + + // Set the intermediate value - this is either 'CBV', or 'not CBV' + // depending on the edge type. + EdgeType ET = getEdgeTypeBetween(PredBB, CurrBB); + VPValue *IntermediateVal = nullptr; + switch (ET) { + case EdgeType::TRUE_EDGE: + // CurrBB is the true successor of PredBB - nothing to do here. + IntermediateVal = CBV; + break; + + case EdgeType::FALSE_EDGE: + // CurrBB is the False successor of PredBB - compute not of CBV. + IntermediateVal = Builder.createNot(CBV); + break; + } + + // Now AND intermediate value with PredBB's block predicate if it has one. + VPValue *BP = PredBB->getPredicate(); + if (BP) + return Builder.createAnd(BP, IntermediateVal); + else + return IntermediateVal; +} + +// Generate a tree of ORs for all IncomingPredicates in WorkList. +// Note: This function destroys the original Worklist. +// +// P1 P2 P3 P4 P5 +// \ / \ / / +// OR1 OR2 / +// \ | / +// \ +/-+ +// \ / | +// OR3 | +// \ | +// OR4 <- Returns this +// | +// +// The algorithm uses a worklist of predicates as its main data structure. +// We pop a pair of values from the front (e.g. P1 and P2), generate an OR +// (in this example OR1), and push it back. In this example the worklist +// contains {P3, P4, P5, OR1}. +// The process iterates until we have only one element in the Worklist (OR4). +// The last element is the root predicate which is returned. +VPValue *VPlanPredicator::genPredicateTree(std::list<VPValue *> &Worklist) { + if (Worklist.empty()) + return nullptr; + + // The worklist initially contains all the leaf nodes. Initialize the tree + // using them. + while (Worklist.size() >= 2) { + // Pop a pair of values from the front. + VPValue *LHS = Worklist.front(); + Worklist.pop_front(); + VPValue *RHS = Worklist.front(); + Worklist.pop_front(); + + // Create an OR of these values. + VPValue *Or = Builder.createOr(LHS, RHS); + + // Push OR to the back of the worklist. + Worklist.push_back(Or); + } + + assert(Worklist.size() == 1 && "Expected 1 item in worklist"); + + // The root is the last node in the worklist. + VPValue *Root = Worklist.front(); + + // This root needs to replace the existing block predicate. This is done in + // the caller function. 
+ return Root; +} + +// Return whether the edge FromBlock -> ToBlock is a TRUE_EDGE or FALSE_EDGE +VPlanPredicator::EdgeType +VPlanPredicator::getEdgeTypeBetween(VPBlockBase *FromBlock, + VPBlockBase *ToBlock) { + unsigned Count = 0; + for (VPBlockBase *SuccBlock : FromBlock->getSuccessors()) { + if (SuccBlock == ToBlock) { + assert(Count < 2 && "Switch not supported currently"); + return (Count == 0) ? EdgeType::TRUE_EDGE : EdgeType::FALSE_EDGE; + } + Count++; + } + + llvm_unreachable("Broken getEdgeTypeBetween"); +} + +// Generate all predicates needed for CurrBlock by going through its immediate +// predecessor blocks. +void VPlanPredicator::createOrPropagatePredicates(VPBlockBase *CurrBlock, + VPRegionBlock *Region) { + // Blocks that dominate region exit inherit the predicate from the region. + // Return after setting the predicate. + if (VPDomTree.dominates(CurrBlock, Region->getExit())) { + VPValue *RegionBP = Region->getPredicate(); + CurrBlock->setPredicate(RegionBP); + return; + } + + // Collect all incoming predicates in a worklist. + std::list<VPValue *> IncomingPredicates; + + // Set the builder's insertion point to the top of the current BB + VPBasicBlock *CurrBB = cast<VPBasicBlock>(CurrBlock->getEntryBasicBlock()); + Builder.setInsertPoint(CurrBB, CurrBB->begin()); + + // For each predecessor, generate the VPInstructions required for + // computing 'BP AND (not) CBV" at the top of CurrBB. + // Collect the outcome of this calculation for all predecessors + // into IncomingPredicates. + for (VPBlockBase *PredBlock : CurrBlock->getPredecessors()) { + // Skip back-edges + if (VPBlockUtils::isBackEdge(PredBlock, CurrBlock, VPLI)) + continue; + + VPValue *IncomingPredicate = nullptr; + unsigned NumPredSuccsNoBE = + VPBlockUtils::countSuccessorsNoBE(PredBlock, VPLI); + + // If there is an unconditional branch to the currBB, then we don't create + // edge predicates. We use the predecessor's block predicate instead. + if (NumPredSuccsNoBE == 1) + IncomingPredicate = PredBlock->getPredicate(); + else if (NumPredSuccsNoBE == 2) { + // Emit recipes into CurrBlock if required + assert(isa<VPBasicBlock>(PredBlock) && "Only BBs have multiple exits"); + IncomingPredicate = + getOrCreateNotPredicate(cast<VPBasicBlock>(PredBlock), CurrBB); + } else + llvm_unreachable("FIXME: switch statement ?"); + + if (IncomingPredicate) + IncomingPredicates.push_back(IncomingPredicate); + } + + // Logically OR all incoming predicates by building the Predicate Tree. + VPValue *Predicate = genPredicateTree(IncomingPredicates); + + // Now update the block's predicate with the new one. + CurrBlock->setPredicate(Predicate); +} + +// Generate all predicates needed for Region. +void VPlanPredicator::predicateRegionRec(VPRegionBlock *Region) { + VPBasicBlock *EntryBlock = cast<VPBasicBlock>(Region->getEntry()); + ReversePostOrderTraversal<VPBlockBase *> RPOT(EntryBlock); + + // Generate edge predicates and append them to the block predicate. RPO is + // necessary since the predecessor blocks' block predicate needs to be set + // before the current block's block predicate can be computed. + for (VPBlockBase *Block : make_range(RPOT.begin(), RPOT.end())) { + // TODO: Handle nested regions once we start generating the same. + assert(!isa<VPRegionBlock>(Block) && "Nested region not expected"); + createOrPropagatePredicates(Block, Region); + } +} + +// Linearize the CFG within Region. +// TODO: Predication and linearization need RPOT for every region. +// This traversal is expensive. 
Since predication is not adding new +// blocks, we should be able to compute RPOT once in predication and +// reuse it here. This becomes even more important once we have nested +// regions. +void VPlanPredicator::linearizeRegionRec(VPRegionBlock *Region) { + ReversePostOrderTraversal<VPBlockBase *> RPOT(Region->getEntry()); + VPBlockBase *PrevBlock = nullptr; + + for (VPBlockBase *CurrBlock : make_range(RPOT.begin(), RPOT.end())) { + // TODO: Handle nested regions once we start generating the same. + assert(!isa<VPRegionBlock>(CurrBlock) && "Nested region not expected"); + + // Linearize control flow by adding an unconditional edge between PrevBlock + // and CurrBlock skipping loop headers and latches to keep intact loop + // header predecessors and loop latch successors. + if (PrevBlock && !VPLI->isLoopHeader(CurrBlock) && + !VPBlockUtils::blockIsLoopLatch(PrevBlock, VPLI)) { + + LLVM_DEBUG(dbgs() << "Linearizing: " << PrevBlock->getName() << "->" + << CurrBlock->getName() << "\n"); + + PrevBlock->clearSuccessors(); + CurrBlock->clearPredecessors(); + VPBlockUtils::connectBlocks(PrevBlock, CurrBlock); + } + + PrevBlock = CurrBlock; + } +} + +// Entry point. The driver function for the predicator. +void VPlanPredicator::predicate(void) { + // Predicate the blocks within Region. + predicateRegionRec(cast<VPRegionBlock>(Plan.getEntry())); + + // Linearize the blocks within Region. + linearizeRegionRec(cast<VPRegionBlock>(Plan.getEntry())); +} + +VPlanPredicator::VPlanPredicator(VPlan &Plan) + : Plan(Plan), VPLI(&(Plan.getVPLoopInfo())) { + // FIXME: Predicator is currently computing the dominator information for the + // top region. Once we start storing dominator information in a VPRegionBlock, + // we can avoid this recalculation. + VPDomTree.recalculate(*(cast<VPRegionBlock>(Plan.getEntry()))); +} diff --git a/lib/Transforms/Vectorize/VPlanPredicator.h b/lib/Transforms/Vectorize/VPlanPredicator.h new file mode 100644 index 000000000000..692afd2978d5 --- /dev/null +++ b/lib/Transforms/Vectorize/VPlanPredicator.h @@ -0,0 +1,74 @@ +//===-- VPlanPredicator.h ---------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines the VPlanPredicator class which contains the public +/// interfaces to predicate and linearize the VPlan region. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_PREDICATOR_H +#define LLVM_TRANSFORMS_VECTORIZE_VPLAN_PREDICATOR_H + +#include "LoopVectorizationPlanner.h" +#include "VPlan.h" +#include "VPlanDominatorTree.h" + +namespace llvm { + +class VPlanPredicator { +private: + enum class EdgeType { + TRUE_EDGE, + FALSE_EDGE, + }; + + // VPlan being predicated. + VPlan &Plan; + + // VPLoopInfo for Plan's HCFG. + VPLoopInfo *VPLI; + + // Dominator tree for Plan's HCFG. + VPDominatorTree VPDomTree; + + // VPlan builder used to generate VPInstructions for block predicates. + VPBuilder Builder; + + /// Get the type of edge from \p FromBlock to \p ToBlock. Returns TRUE_EDGE if + /// \p ToBlock is either the unconditional successor or the conditional true + /// successor of \p FromBlock and FALSE_EDGE otherwise.
+ EdgeType getEdgeTypeBetween(VPBlockBase *FromBlock, VPBlockBase *ToBlock); + + /// Create and return VPValue corresponding to the predicate for the edge from + /// \p PredBB to \p CurrentBlock. + VPValue *getOrCreateNotPredicate(VPBasicBlock *PredBB, VPBasicBlock *CurrBB); + + /// Generate and return the result of ORing all the predicate VPValues in \p + /// Worklist. + VPValue *genPredicateTree(std::list<VPValue *> &Worklist); + + /// Create or propagate predicate for \p CurrBlock in region \p Region using + /// predicate(s) of its predecessor(s) + void createOrPropagatePredicates(VPBlockBase *CurrBlock, + VPRegionBlock *Region); + + /// Predicate the CFG within \p Region. + void predicateRegionRec(VPRegionBlock *Region); + + /// Linearize the CFG within \p Region. + void linearizeRegionRec(VPRegionBlock *Region); + +public: + VPlanPredicator(VPlan &Plan); + + /// Predicate Plan's HCFG. + void predicate(void); +}; +} // end namespace llvm +#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_PREDICATOR_H diff --git a/lib/Transforms/Vectorize/VPlanSLP.cpp b/lib/Transforms/Vectorize/VPlanSLP.cpp index ad3a85a6f760..e5ab24e52df6 100644 --- a/lib/Transforms/Vectorize/VPlanSLP.cpp +++ b/lib/Transforms/Vectorize/VPlanSLP.cpp @@ -1,9 +1,8 @@ //===- VPlanSLP.cpp - SLP Analysis based on VPlan -------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// This file implements SLP analysis based on VPlan. The analysis is based on diff --git a/lib/Transforms/Vectorize/VPlanValue.h b/lib/Transforms/Vectorize/VPlanValue.h index b473579b699f..7b6c228c229e 100644 --- a/lib/Transforms/Vectorize/VPlanValue.h +++ b/lib/Transforms/Vectorize/VPlanValue.h @@ -1,9 +1,8 @@ //===- VPlanValue.h - Represent Values in Vectorizer Plan -----------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// diff --git a/lib/Transforms/Vectorize/VPlanVerifier.cpp b/lib/Transforms/Vectorize/VPlanVerifier.cpp index 054bed4e177f..394b1b93113b 100644 --- a/lib/Transforms/Vectorize/VPlanVerifier.cpp +++ b/lib/Transforms/Vectorize/VPlanVerifier.cpp @@ -1,9 +1,8 @@ //===-- VPlanVerifier.cpp -------------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// diff --git a/lib/Transforms/Vectorize/VPlanVerifier.h b/lib/Transforms/Vectorize/VPlanVerifier.h index d2f99d006a66..7d2b26252172 100644 --- a/lib/Transforms/Vectorize/VPlanVerifier.h +++ b/lib/Transforms/Vectorize/VPlanVerifier.h @@ -1,9 +1,8 @@ //===-- VPlanVerifier.h -----------------------------------------*- C++ -*-===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// diff --git a/lib/Transforms/Vectorize/Vectorize.cpp b/lib/Transforms/Vectorize/Vectorize.cpp index 559ab1968844..6a4f9169c2af 100644 --- a/lib/Transforms/Vectorize/Vectorize.cpp +++ b/lib/Transforms/Vectorize/Vectorize.cpp @@ -1,9 +1,8 @@ //===-- Vectorize.cpp -----------------------------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // |
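Editor's note: the following standalone sketch is not part of the patch above. It illustrates the worklist algorithm that the new VPlanPredicator::genPredicateTree uses to OR together the incoming predicates of a block. VPValue and VPBuilder::createOr are replaced by plain strings and a string-building helper (assumptions made only for this illustration) so the sketch compiles and runs on its own.

// Standalone sketch of the genPredicateTree worklist reduction.
#include <cassert>
#include <iostream>
#include <list>
#include <string>

// Stand-in for VPBuilder::createOr: combine two predicates into one.
static std::string createOr(const std::string &LHS, const std::string &RHS) {
  return "(" + LHS + " | " + RHS + ")";
}

// Pop a pair of predicates from the front, OR them, push the result to the
// back, and repeat until a single root predicate remains.
static std::string genPredicateTree(std::list<std::string> Worklist) {
  if (Worklist.empty())
    return "";
  while (Worklist.size() >= 2) {
    std::string LHS = Worklist.front();
    Worklist.pop_front();
    std::string RHS = Worklist.front();
    Worklist.pop_front();
    Worklist.push_back(createOr(LHS, RHS));
  }
  assert(Worklist.size() == 1 && "Expected a single root predicate");
  return Worklist.front();
}

int main() {
  // Mirrors the P1..P5 example in the comment above the real function: this
  // prints ((P3 | P4) | (P5 | (P1 | P2))), a single root ORing all five inputs.
  std::cout << genPredicateTree({"P1", "P2", "P3", "P4", "P5"}) << "\n";
  return 0;
}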
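Editor's note: a second standalone sketch, also not part of the patch, models the rewritten store-pairing scan in SLPVectorizerPass::vectorizeStores. Integer addresses, a std::map chain, and a difference-of-one test are simplified stand-ins (assumptions for illustration only) for StoreInst pointers, the real ConsecutiveChain map, and isConsecutiveAccess; the nearest-first probing order Idx-1, Idx+1, Idx-2, Idx+2, ... is the part taken from the patch.

// Standalone sketch of the nearest-first consecutive-store search.
#include <algorithm>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Hypothetical store addresses: store 0 (addr 8) immediately precedes
  // store 2 (addr 9), and no other pair is consecutive.
  std::vector<int> Addr = {8, 12, 9, 20};
  int E = static_cast<int>(Addr.size());
  std::map<int, int> ConsecutiveChain; // head index -> tail index

  // Stand-in for the FindConsecutiveAccess lambda: record the pair and return
  // true iff store K immediately precedes store Idx.
  auto FindConsecutiveAccess = [&](int K, int Idx) {
    if (Addr[Idx] - Addr[K] != 1)
      return false;
    ConsecutiveChain[K] = Idx;
    return true;
  };

  // Walk the stores in reverse and, for each one, probe candidates at offsets
  // -1, +1, -2, +2, ... stopping at the first match, as the patched loop does.
  for (int Idx = E - 1; Idx >= 0; --Idx)
    for (int Offset = 1, F = std::max(E - Idx, Idx + 1); Offset < F; ++Offset)
      if ((Idx >= Offset && FindConsecutiveAccess(Idx - Offset, Idx)) ||
          (Idx + Offset < E && FindConsecutiveAccess(Idx + Offset, Idx)))
        break;

  for (const auto &Link : ConsecutiveChain)
    std::cout << "store " << Link.first << " -> store " << Link.second << "\n";
  return 0;
}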