diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp | 2074 |
1 files changed, 2074 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp new file mode 100644 index 000000000000..8573b016d1e5 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -0,0 +1,2074 @@ +//===- ComplexDeinterleavingPass.cpp --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Identification: +// This step is responsible for finding the patterns that can be lowered to +// complex instructions, and building a graph to represent the complex +// structures. Starting from the "Converging Shuffle" (a shuffle that +// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the +// operands are evaluated and identified as "Composite Nodes" (collections of +// instructions that can potentially be lowered to a single complex +// instruction). This is performed by checking the real and imaginary components +// and tracking the data flow for each component while following the operand +// pairs. Validity of each node is expected to be done upon creation, and any +// validation errors should halt traversal and prevent further graph +// construction. +// Instead of relying on Shuffle operations, vector interleaving and +// deinterleaving can be represented by vector.interleave2 and +// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by +// these intrinsics, whereas, fixed-width vectors are recognized for both +// shufflevector instruction and intrinsics. +// +// Replacement: +// This step traverses the graph built up by identification, delegating to the +// target to validate and generate the correct intrinsics, and plumbs them +// together connecting each end of the new intrinsics graph to the existing +// use-def chain. This step is assumed to finish successfully, as all +// information is expected to be correct by this point. +// +// +// Internal data structure: +// ComplexDeinterleavingGraph: +// Keeps references to all the valid CompositeNodes formed as part of the +// transformation, and every Instruction contained within said nodes. It also +// holds onto a reference to the root Instruction, and the root node that should +// replace it. +// +// ComplexDeinterleavingCompositeNode: +// A CompositeNode represents a single transformation point; each node should +// transform into a single complex instruction (ignoring vector splitting, which +// would generate more instructions per node). They are identified in a +// depth-first manner, traversing and identifying the operands of each +// instruction in the order they appear in the IR. +// Each node maintains a reference to its Real and Imaginary instructions, +// as well as any additional instructions that make up the identified operation +// (Internal instructions should only have uses within their containing node). +// A Node also contains the rotation and operation type that it represents. +// Operands contains pointers to other CompositeNodes, acting as the edges in +// the graph. ReplacementValue is the transformed Value* that has been emitted +// to the IR. +// +// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and +// ReplacementValue fields of that Node are relevant, where the ReplacementValue +// should be pre-populated. +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/ComplexDeinterleavingPass.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/InitializePasses.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/Utils/Local.h" +#include <algorithm> + +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "complex-deinterleaving" + +STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); + +static cl::opt<bool> ComplexDeinterleavingEnabled( + "enable-complex-deinterleaving", + cl::desc("Enable generation of complex instructions"), cl::init(true), + cl::Hidden); + +/// Checks the given mask, and determines whether said mask is interleaving. +/// +/// To be interleaving, a mask must alternate between `i` and `i + (Length / +/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a +/// 4x vector interleaving mask would be <0, 2, 1, 3>). +static bool isInterleavingMask(ArrayRef<int> Mask); + +/// Checks the given mask, and determines whether said mask is deinterleaving. +/// +/// To be deinterleaving, a mask must increment in steps of 2, and either start +/// with 0 or 1. +/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or +/// <1, 3, 5, 7>). +static bool isDeinterleavingMask(ArrayRef<int> Mask); + +/// Returns true if the operation is a negation of V, and it works for both +/// integers and floats. +static bool isNeg(Value *V); + +/// Returns the operand for negation operation. +static Value *getNegOperand(Value *V); + +namespace { + +class ComplexDeinterleavingLegacyPass : public FunctionPass { +public: + static char ID; + + ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) + : FunctionPass(ID), TM(TM) { + initializeComplexDeinterleavingLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { + return "Complex Deinterleaving Pass"; + } + + bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.setPreservesCFG(); + } + +private: + const TargetMachine *TM; +}; + +class ComplexDeinterleavingGraph; +struct ComplexDeinterleavingCompositeNode { + + ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, + Value *R, Value *I) + : Operation(Op), Real(R), Imag(I) {} + +private: + friend class ComplexDeinterleavingGraph; + using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; + using RawNodePtr = ComplexDeinterleavingCompositeNode *; + +public: + ComplexDeinterleavingOperation Operation; + Value *Real; + Value *Imag; + + // This two members are required exclusively for generating + // ComplexDeinterleavingOperation::Symmetric operations. + unsigned Opcode; + std::optional<FastMathFlags> Flags; + + ComplexDeinterleavingRotation Rotation = + ComplexDeinterleavingRotation::Rotation_0; + SmallVector<RawNodePtr> Operands; + Value *ReplacementNode = nullptr; + + void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } + + void dump() { dump(dbgs()); } + void dump(raw_ostream &OS) { + auto PrintValue = [&](Value *V) { + if (V) { + OS << "\""; + V->print(OS, true); + OS << "\"\n"; + } else + OS << "nullptr\n"; + }; + auto PrintNodeRef = [&](RawNodePtr Ptr) { + if (Ptr) + OS << Ptr << "\n"; + else + OS << "nullptr\n"; + }; + + OS << "- CompositeNode: " << this << "\n"; + OS << " Real: "; + PrintValue(Real); + OS << " Imag: "; + PrintValue(Imag); + OS << " ReplacementNode: "; + PrintValue(ReplacementNode); + OS << " Operation: " << (int)Operation << "\n"; + OS << " Rotation: " << ((int)Rotation * 90) << "\n"; + OS << " Operands: \n"; + for (const auto &Op : Operands) { + OS << " - "; + PrintNodeRef(Op); + } + } +}; + +class ComplexDeinterleavingGraph { +public: + struct Product { + Value *Multiplier; + Value *Multiplicand; + bool IsPositive; + }; + + using Addend = std::pair<Value *, bool>; + using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; + using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; + + // Helper struct for holding info about potential partial multiplication + // candidates + struct PartialMulCandidate { + Value *Common; + NodePtr Node; + unsigned RealIdx; + unsigned ImagIdx; + bool IsNodeInverted; + }; + + explicit ComplexDeinterleavingGraph(const TargetLowering *TL, + const TargetLibraryInfo *TLI) + : TL(TL), TLI(TLI) {} + +private: + const TargetLowering *TL = nullptr; + const TargetLibraryInfo *TLI = nullptr; + SmallVector<NodePtr> CompositeNodes; + DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult; + + SmallPtrSet<Instruction *, 16> FinalInstructions; + + /// Root instructions are instructions from which complex computation starts + std::map<Instruction *, NodePtr> RootToNode; + + /// Topologically sorted root instructions + SmallVector<Instruction *, 1> OrderedRoots; + + /// When examining a basic block for complex deinterleaving, if it is a simple + /// one-block loop, then the only incoming block is 'Incoming' and the + /// 'BackEdge' block is the block itself." + BasicBlock *BackEdge = nullptr; + BasicBlock *Incoming = nullptr; + + /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction + /// %OutsideUser as it is shown in the IR: + /// + /// vector.body: + /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], + /// [ %ReductionOp, %vector.body ] + /// ... + /// %ReductionOp = fadd i64 ... + /// ... + /// br i1 %condition, label %vector.body, %middle.block + /// + /// middle.block: + /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) + /// + /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding + /// `llvm.vector.reduce.fadd` when unroll factor isn't one. + MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; + + /// In the process of detecting a reduction, we consider a pair of + /// %ReductionOP, which we refer to as real and imag (or vice versa), and + /// traverse the use-tree to detect complex operations. As this is a reduction + /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds + /// to the %ReductionOPs that we suspect to be complex. + /// RealPHI and ImagPHI are used by the identifyPHINode method. + PHINode *RealPHI = nullptr; + PHINode *ImagPHI = nullptr; + + /// Set this flag to true if RealPHI and ImagPHI were reached during reduction + /// detection. + bool PHIsFound = false; + + /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. + /// The new PHINode corresponds to a vector of deinterleaved complex numbers. + /// This mapping is populated during + /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then + /// used in the ComplexDeinterleavingOperation::ReductionOperation node + /// replacement process. + std::map<PHINode *, PHINode *> OldToNewPHI; + + NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, + Value *R, Value *I) { + assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && + Operation != ComplexDeinterleavingOperation::ReductionOperation) || + (R && I)) && + "Reduction related nodes must have Real and Imaginary parts"); + return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, + I); + } + + NodePtr submitCompositeNode(NodePtr Node) { + CompositeNodes.push_back(Node); + if (Node->Real && Node->Imag) + CachedResult[{Node->Real, Node->Imag}] = Node; + return Node; + } + + /// Identifies a complex partial multiply pattern and its rotation, based on + /// the following patterns + /// + /// 0: r: cr + ar * br + /// i: ci + ar * bi + /// 90: r: cr - ai * bi + /// i: ci + ai * br + /// 180: r: cr - ar * br + /// i: ci - ar * bi + /// 270: r: cr + ai * bi + /// i: ci - ai * br + NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); + + /// Identify the other branch of a Partial Mul, taking the CommonOperandI that + /// is partially known from identifyPartialMul, filling in the other half of + /// the complex pair. + NodePtr + identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, + std::pair<Value *, Value *> &CommonOperandI); + + /// Identifies a complex add pattern and its rotation, based on the following + /// patterns. + /// + /// 90: r: ar - bi + /// i: ai + br + /// 270: r: ar + bi + /// i: ai - br + NodePtr identifyAdd(Instruction *Real, Instruction *Imag); + NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); + + NodePtr identifyNode(Value *R, Value *I); + + /// Determine if a sum of complex numbers can be formed from \p RealAddends + /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. + /// Return nullptr if it is not possible to construct a complex number. + /// \p Flags are needed to generate symmetric Add and Sub operations. + NodePtr identifyAdditions(std::list<Addend> &RealAddends, + std::list<Addend> &ImagAddends, + std::optional<FastMathFlags> Flags, + NodePtr Accumulator); + + /// Extract one addend that have both real and imaginary parts positive. + NodePtr extractPositiveAddend(std::list<Addend> &RealAddends, + std::list<Addend> &ImagAddends); + + /// Determine if sum of multiplications of complex numbers can be formed from + /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result + /// to it. Return nullptr if it is not possible to construct a complex number. + NodePtr identifyMultiplications(std::vector<Product> &RealMuls, + std::vector<Product> &ImagMuls, + NodePtr Accumulator); + + /// Go through pairs of multiplication (one Real and one Imag) and find all + /// possible candidates for partial multiplication and put them into \p + /// Candidates. Returns true if all Product has pair with common operand + bool collectPartialMuls(const std::vector<Product> &RealMuls, + const std::vector<Product> &ImagMuls, + std::vector<PartialMulCandidate> &Candidates); + + /// If the code is compiled with -Ofast or expressions have `reassoc` flag, + /// the order of complex computation operations may be significantly altered, + /// and the real and imaginary parts may not be executed in parallel. This + /// function takes this into consideration and employs a more general approach + /// to identify complex computations. Initially, it gathers all the addends + /// and multiplicands and then constructs a complex expression from them. + NodePtr identifyReassocNodes(Instruction *I, Instruction *J); + + NodePtr identifyRoot(Instruction *I); + + /// Identifies the Deinterleave operation applied to a vector containing + /// complex numbers. There are two ways to represent the Deinterleave + /// operation: + /// * Using two shufflevectors with even indices for /pReal instruction and + /// odd indices for /pImag instructions (only for fixed-width vectors) + /// * Using two extractvalue instructions applied to `vector.deinterleave2` + /// intrinsic (for both fixed and scalable vectors) + NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); + + /// identifying the operation that represents a complex number repeated in a + /// Splat vector. There are two possible types of splats: ConstantExpr with + /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an + /// initialization mask with all values set to zero. + NodePtr identifySplat(Value *Real, Value *Imag); + + NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); + + /// Identifies SelectInsts in a loop that has reduction with predication masks + /// and/or predicated tail folding + NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); + + Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); + + /// Complete IR modifications after producing new reduction operation: + /// * Populate the PHINode generated for + /// ComplexDeinterleavingOperation::ReductionPHI + /// * Deinterleave the final value outside of the loop and repurpose original + /// reduction users + void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); + +public: + void dump() { dump(dbgs()); } + void dump(raw_ostream &OS) { + for (const auto &Node : CompositeNodes) + Node->dump(OS); + } + + /// Returns false if the deinterleaving operation should be cancelled for the + /// current graph. + bool identifyNodes(Instruction *RootI); + + /// In case \pB is one-block loop, this function seeks potential reductions + /// and populates ReductionInfo. Returns true if any reductions were + /// identified. + bool collectPotentialReductions(BasicBlock *B); + + void identifyReductionNodes(); + + /// Check that every instruction, from the roots to the leaves, has internal + /// uses. + bool checkNodes(); + + /// Perform the actual replacement of the underlying instruction graph. + void replaceNodes(); +}; + +class ComplexDeinterleaving { +public: + ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) + : TL(tl), TLI(tli) {} + bool runOnFunction(Function &F); + +private: + bool evaluateBasicBlock(BasicBlock *B); + + const TargetLowering *TL = nullptr; + const TargetLibraryInfo *TLI = nullptr; +}; + +} // namespace + +char ComplexDeinterleavingLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, + "Complex Deinterleaving", false, false) +INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, + "Complex Deinterleaving", false, false) + +PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, + FunctionAnalysisManager &AM) { + const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); + auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); + if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve<FunctionAnalysisManagerModuleProxy>(); + return PA; +} + +FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { + return new ComplexDeinterleavingLegacyPass(TM); +} + +bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { + const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); + auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); +} + +bool ComplexDeinterleaving::runOnFunction(Function &F) { + if (!ComplexDeinterleavingEnabled) { + LLVM_DEBUG( + dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); + return false; + } + + if (!TL->isComplexDeinterleavingSupported()) { + LLVM_DEBUG( + dbgs() << "Complex deinterleaving has been disabled, target does " + "not support lowering of complex number operations.\n"); + return false; + } + + bool Changed = false; + for (auto &B : F) + Changed |= evaluateBasicBlock(&B); + + return Changed; +} + +static bool isInterleavingMask(ArrayRef<int> Mask) { + // If the size is not even, it's not an interleaving mask + if ((Mask.size() & 1)) + return false; + + int HalfNumElements = Mask.size() / 2; + for (int Idx = 0; Idx < HalfNumElements; ++Idx) { + int MaskIdx = Idx * 2; + if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) + return false; + } + + return true; +} + +static bool isDeinterleavingMask(ArrayRef<int> Mask) { + int Offset = Mask[0]; + int HalfNumElements = Mask.size() / 2; + + for (int Idx = 1; Idx < HalfNumElements; ++Idx) { + if (Mask[Idx] != (Idx * 2) + Offset) + return false; + } + + return true; +} + +bool isNeg(Value *V) { + return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value())); +} + +Value *getNegOperand(Value *V) { + assert(isNeg(V)); + auto *I = cast<Instruction>(V); + if (I->getOpcode() == Instruction::FNeg) + return I->getOperand(0); + + return I->getOperand(1); +} + +bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { + ComplexDeinterleavingGraph Graph(TL, TLI); + if (Graph.collectPotentialReductions(B)) + Graph.identifyReductionNodes(); + + for (auto &I : *B) + Graph.identifyNodes(&I); + + if (Graph.checkNodes()) { + Graph.replaceNodes(); + return true; + } + + return false; +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( + Instruction *Real, Instruction *Imag, + std::pair<Value *, Value *> &PartialMatch) { + LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag + << "\n"); + + if (!Real->hasOneUse() || !Imag->hasOneUse()) { + LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); + return nullptr; + } + + if ((Real->getOpcode() != Instruction::FMul && + Real->getOpcode() != Instruction::Mul) || + (Imag->getOpcode() != Instruction::FMul && + Imag->getOpcode() != Instruction::Mul)) { + LLVM_DEBUG( + dbgs() << " - Real or imaginary instruction is not fmul or mul\n"); + return nullptr; + } + + Value *R0 = Real->getOperand(0); + Value *R1 = Real->getOperand(1); + Value *I0 = Imag->getOperand(0); + Value *I1 = Imag->getOperand(1); + + // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the + // rotations and use the operand. + unsigned Negs = 0; + Value *Op; + if (match(R0, m_Neg(m_Value(Op)))) { + Negs |= 1; + R0 = Op; + } else if (match(R1, m_Neg(m_Value(Op)))) { + Negs |= 1; + R1 = Op; + } + + if (isNeg(I0)) { + Negs |= 2; + Negs ^= 1; + I0 = Op; + } else if (match(I1, m_Neg(m_Value(Op)))) { + Negs |= 2; + Negs ^= 1; + I1 = Op; + } + + ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; + + Value *CommonOperand; + Value *UncommonRealOp; + Value *UncommonImagOp; + + if (R0 == I0 || R0 == I1) { + CommonOperand = R0; + UncommonRealOp = R1; + } else if (R1 == I0 || R1 == I1) { + CommonOperand = R1; + UncommonRealOp = R0; + } else { + LLVM_DEBUG(dbgs() << " - No equal operand\n"); + return nullptr; + } + + UncommonImagOp = (CommonOperand == I0) ? I1 : I0; + if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || + Rotation == ComplexDeinterleavingRotation::Rotation_270) + std::swap(UncommonRealOp, UncommonImagOp); + + // Between identifyPartialMul and here we need to have found a complete valid + // pair from the CommonOperand of each part. + if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || + Rotation == ComplexDeinterleavingRotation::Rotation_180) + PartialMatch.first = CommonOperand; + else + PartialMatch.second = CommonOperand; + + if (!PartialMatch.first || !PartialMatch.second) { + LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); + return nullptr; + } + + NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); + if (!CommonNode) { + LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); + return nullptr; + } + + NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); + if (!UncommonNode) { + LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); + return nullptr; + } + + NodePtr Node = prepareCompositeNode( + ComplexDeinterleavingOperation::CMulPartial, Real, Imag); + Node->Rotation = Rotation; + Node->addOperand(CommonNode); + Node->addOperand(UncommonNode); + return submitCompositeNode(Node); +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, + Instruction *Imag) { + LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag + << "\n"); + // Determine rotation + auto IsAdd = [](unsigned Op) { + return Op == Instruction::FAdd || Op == Instruction::Add; + }; + auto IsSub = [](unsigned Op) { + return Op == Instruction::FSub || Op == Instruction::Sub; + }; + ComplexDeinterleavingRotation Rotation; + if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode())) + Rotation = ComplexDeinterleavingRotation::Rotation_0; + else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode())) + Rotation = ComplexDeinterleavingRotation::Rotation_90; + else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode())) + Rotation = ComplexDeinterleavingRotation::Rotation_180; + else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode())) + Rotation = ComplexDeinterleavingRotation::Rotation_270; + else { + LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); + return nullptr; + } + + if (isa<FPMathOperator>(Real) && + (!Real->getFastMathFlags().allowContract() || + !Imag->getFastMathFlags().allowContract())) { + LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); + return nullptr; + } + + Value *CR = Real->getOperand(0); + Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); + if (!RealMulI) + return nullptr; + Value *CI = Imag->getOperand(0); + Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); + if (!ImagMulI) + return nullptr; + + if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { + LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); + return nullptr; + } + + Value *R0 = RealMulI->getOperand(0); + Value *R1 = RealMulI->getOperand(1); + Value *I0 = ImagMulI->getOperand(0); + Value *I1 = ImagMulI->getOperand(1); + + Value *CommonOperand; + Value *UncommonRealOp; + Value *UncommonImagOp; + + if (R0 == I0 || R0 == I1) { + CommonOperand = R0; + UncommonRealOp = R1; + } else if (R1 == I0 || R1 == I1) { + CommonOperand = R1; + UncommonRealOp = R0; + } else { + LLVM_DEBUG(dbgs() << " - No equal operand\n"); + return nullptr; + } + + UncommonImagOp = (CommonOperand == I0) ? I1 : I0; + if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || + Rotation == ComplexDeinterleavingRotation::Rotation_270) + std::swap(UncommonRealOp, UncommonImagOp); + + std::pair<Value *, Value *> PartialMatch( + (Rotation == ComplexDeinterleavingRotation::Rotation_0 || + Rotation == ComplexDeinterleavingRotation::Rotation_180) + ? CommonOperand + : nullptr, + (Rotation == ComplexDeinterleavingRotation::Rotation_90 || + Rotation == ComplexDeinterleavingRotation::Rotation_270) + ? CommonOperand + : nullptr); + + auto *CRInst = dyn_cast<Instruction>(CR); + auto *CIInst = dyn_cast<Instruction>(CI); + + if (!CRInst || !CIInst) { + LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); + return nullptr; + } + + NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); + if (!CNode) { + LLVM_DEBUG(dbgs() << " - No cnode identified\n"); + return nullptr; + } + + NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); + if (!UncommonRes) { + LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); + return nullptr; + } + + assert(PartialMatch.first && PartialMatch.second); + NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); + if (!CommonRes) { + LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); + return nullptr; + } + + NodePtr Node = prepareCompositeNode( + ComplexDeinterleavingOperation::CMulPartial, Real, Imag); + Node->Rotation = Rotation; + Node->addOperand(CommonRes); + Node->addOperand(UncommonRes); + Node->addOperand(CNode); + return submitCompositeNode(Node); +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { + LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); + + // Determine rotation + ComplexDeinterleavingRotation Rotation; + if ((Real->getOpcode() == Instruction::FSub && + Imag->getOpcode() == Instruction::FAdd) || + (Real->getOpcode() == Instruction::Sub && + Imag->getOpcode() == Instruction::Add)) + Rotation = ComplexDeinterleavingRotation::Rotation_90; + else if ((Real->getOpcode() == Instruction::FAdd && + Imag->getOpcode() == Instruction::FSub) || + (Real->getOpcode() == Instruction::Add && + Imag->getOpcode() == Instruction::Sub)) + Rotation = ComplexDeinterleavingRotation::Rotation_270; + else { + LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n"); + return nullptr; + } + + auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); + auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); + auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); + auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); + + if (!AR || !AI || !BR || !BI) { + LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); + return nullptr; + } + + NodePtr ResA = identifyNode(AR, AI); + if (!ResA) { + LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); + return nullptr; + } + NodePtr ResB = identifyNode(BR, BI); + if (!ResB) { + LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); + return nullptr; + } + + NodePtr Node = + prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); + Node->Rotation = Rotation; + Node->addOperand(ResA); + Node->addOperand(ResB); + return submitCompositeNode(Node); +} + +static bool isInstructionPairAdd(Instruction *A, Instruction *B) { + unsigned OpcA = A->getOpcode(); + unsigned OpcB = B->getOpcode(); + + return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || + (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || + (OpcA == Instruction::Sub && OpcB == Instruction::Add) || + (OpcA == Instruction::Add && OpcB == Instruction::Sub); +} + +static bool isInstructionPairMul(Instruction *A, Instruction *B) { + auto Pattern = + m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); + + return match(A, Pattern) && match(B, Pattern); +} + +static bool isInstructionPotentiallySymmetric(Instruction *I) { + switch (I->getOpcode()) { + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FNeg: + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + return true; + default: + return false; + } +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, + Instruction *Imag) { + if (Real->getOpcode() != Imag->getOpcode()) + return nullptr; + + if (!isInstructionPotentiallySymmetric(Real) || + !isInstructionPotentiallySymmetric(Imag)) + return nullptr; + + auto *R0 = Real->getOperand(0); + auto *I0 = Imag->getOperand(0); + + NodePtr Op0 = identifyNode(R0, I0); + NodePtr Op1 = nullptr; + if (Op0 == nullptr) + return nullptr; + + if (Real->isBinaryOp()) { + auto *R1 = Real->getOperand(1); + auto *I1 = Imag->getOperand(1); + Op1 = identifyNode(R1, I1); + if (Op1 == nullptr) + return nullptr; + } + + if (isa<FPMathOperator>(Real) && + Real->getFastMathFlags() != Imag->getFastMathFlags()) + return nullptr; + + auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, + Real, Imag); + Node->Opcode = Real->getOpcode(); + if (isa<FPMathOperator>(Real)) + Node->Flags = Real->getFastMathFlags(); + + Node->addOperand(Op0); + if (Real->isBinaryOp()) + Node->addOperand(Op1); + + return submitCompositeNode(Node); +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { + LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); + assert(R->getType() == I->getType() && + "Real and imaginary parts should not have different types"); + + auto It = CachedResult.find({R, I}); + if (It != CachedResult.end()) { + LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); + return It->second; + } + + if (NodePtr CN = identifySplat(R, I)) + return CN; + + auto *Real = dyn_cast<Instruction>(R); + auto *Imag = dyn_cast<Instruction>(I); + if (!Real || !Imag) + return nullptr; + + if (NodePtr CN = identifyDeinterleave(Real, Imag)) + return CN; + + if (NodePtr CN = identifyPHINode(Real, Imag)) + return CN; + + if (NodePtr CN = identifySelectNode(Real, Imag)) + return CN; + + auto *VTy = cast<VectorType>(Real->getType()); + auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); + + bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( + ComplexDeinterleavingOperation::CMulPartial, NewVTy); + bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( + ComplexDeinterleavingOperation::CAdd, NewVTy); + + if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { + if (NodePtr CN = identifyPartialMul(Real, Imag)) + return CN; + } + + if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { + if (NodePtr CN = identifyAdd(Real, Imag)) + return CN; + } + + if (HasCMulSupport && HasCAddSupport) { + if (NodePtr CN = identifyReassocNodes(Real, Imag)) + return CN; + } + + if (NodePtr CN = identifySymmetricOperation(Real, Imag)) + return CN; + + LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); + CachedResult[{R, I}] = nullptr; + return nullptr; +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, + Instruction *Imag) { + auto IsOperationSupported = [](unsigned Opcode) -> bool { + return Opcode == Instruction::FAdd || Opcode == Instruction::FSub || + Opcode == Instruction::FNeg || Opcode == Instruction::Add || + Opcode == Instruction::Sub; + }; + + if (!IsOperationSupported(Real->getOpcode()) || + !IsOperationSupported(Imag->getOpcode())) + return nullptr; + + std::optional<FastMathFlags> Flags; + if (isa<FPMathOperator>(Real)) { + if (Real->getFastMathFlags() != Imag->getFastMathFlags()) { + LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are " + "not identical\n"); + return nullptr; + } + + Flags = Real->getFastMathFlags(); + if (!Flags->allowReassoc()) { + LLVM_DEBUG( + dbgs() + << "the 'Reassoc' attribute is missing in the FastMath flags\n"); + return nullptr; + } + } + + // Collect multiplications and addend instructions from the given instruction + // while traversing it operands. Additionally, verify that all instructions + // have the same fast math flags. + auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, + std::list<Addend> &Addends) -> bool { + SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; + SmallPtrSet<Value *, 8> Visited; + while (!Worklist.empty()) { + auto [V, IsPositive] = Worklist.back(); + Worklist.pop_back(); + if (!Visited.insert(V).second) + continue; + + Instruction *I = dyn_cast<Instruction>(V); + if (!I) { + Addends.emplace_back(V, IsPositive); + continue; + } + + // If an instruction has more than one user, it indicates that it either + // has an external user, which will be later checked by the checkNodes + // function, or it is a subexpression utilized by multiple expressions. In + // the latter case, we will attempt to separately identify the complex + // operation from here in order to create a shared + // ComplexDeinterleavingCompositeNode. + if (I != Insn && I->getNumUses() > 1) { + LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); + Addends.emplace_back(I, IsPositive); + continue; + } + switch (I->getOpcode()) { + case Instruction::FAdd: + case Instruction::Add: + Worklist.emplace_back(I->getOperand(1), IsPositive); + Worklist.emplace_back(I->getOperand(0), IsPositive); + break; + case Instruction::FSub: + Worklist.emplace_back(I->getOperand(1), !IsPositive); + Worklist.emplace_back(I->getOperand(0), IsPositive); + break; + case Instruction::Sub: + if (isNeg(I)) { + Worklist.emplace_back(getNegOperand(I), !IsPositive); + } else { + Worklist.emplace_back(I->getOperand(1), !IsPositive); + Worklist.emplace_back(I->getOperand(0), IsPositive); + } + break; + case Instruction::FMul: + case Instruction::Mul: { + Value *A, *B; + if (isNeg(I->getOperand(0))) { + A = getNegOperand(I->getOperand(0)); + IsPositive = !IsPositive; + } else { + A = I->getOperand(0); + } + + if (isNeg(I->getOperand(1))) { + B = getNegOperand(I->getOperand(1)); + IsPositive = !IsPositive; + } else { + B = I->getOperand(1); + } + Muls.push_back(Product{A, B, IsPositive}); + break; + } + case Instruction::FNeg: + Worklist.emplace_back(I->getOperand(0), !IsPositive); + break; + default: + Addends.emplace_back(I, IsPositive); + continue; + } + + if (Flags && I->getFastMathFlags() != *Flags) { + LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " + "inconsistent with the root instructions' flags: " + << *I << "\n"); + return false; + } + } + return true; + }; + + std::vector<Product> RealMuls, ImagMuls; + std::list<Addend> RealAddends, ImagAddends; + if (!Collect(Real, RealMuls, RealAddends) || + !Collect(Imag, ImagMuls, ImagAddends)) + return nullptr; + + if (RealAddends.size() != ImagAddends.size()) + return nullptr; + + NodePtr FinalNode; + if (!RealMuls.empty() || !ImagMuls.empty()) { + // If there are multiplicands, extract positive addend and use it as an + // accumulator + FinalNode = extractPositiveAddend(RealAddends, ImagAddends); + FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); + if (!FinalNode) + return nullptr; + } + + // Identify and process remaining additions + if (!RealAddends.empty() || !ImagAddends.empty()) { + FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); + if (!FinalNode) + return nullptr; + } + assert(FinalNode && "FinalNode can not be nullptr here"); + // Set the Real and Imag fields of the final node and submit it + FinalNode->Real = Real; + FinalNode->Imag = Imag; + submitCompositeNode(FinalNode); + return FinalNode; +} + +bool ComplexDeinterleavingGraph::collectPartialMuls( + const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, + std::vector<PartialMulCandidate> &PartialMulCandidates) { + // Helper function to extract a common operand from two products + auto FindCommonInstruction = [](const Product &Real, + const Product &Imag) -> Value * { + if (Real.Multiplicand == Imag.Multiplicand || + Real.Multiplicand == Imag.Multiplier) + return Real.Multiplicand; + + if (Real.Multiplier == Imag.Multiplicand || + Real.Multiplier == Imag.Multiplier) + return Real.Multiplier; + + return nullptr; + }; + + // Iterating over real and imaginary multiplications to find common operands + // If a common operand is found, a partial multiplication candidate is created + // and added to the candidates vector The function returns false if no common + // operands are found for any product + for (unsigned i = 0; i < RealMuls.size(); ++i) { + bool FoundCommon = false; + for (unsigned j = 0; j < ImagMuls.size(); ++j) { + auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]); + if (!Common) + continue; + + auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier + : RealMuls[i].Multiplicand; + auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier + : ImagMuls[j].Multiplicand; + + auto Node = identifyNode(A, B); + if (Node) { + FoundCommon = true; + PartialMulCandidates.push_back({Common, Node, i, j, false}); + } + + Node = identifyNode(B, A); + if (Node) { + FoundCommon = true; + PartialMulCandidates.push_back({Common, Node, i, j, true}); + } + } + if (!FoundCommon) + return false; + } + return true; +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyMultiplications( + std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, + NodePtr Accumulator = nullptr) { + if (RealMuls.size() != ImagMuls.size()) + return nullptr; + + std::vector<PartialMulCandidate> Info; + if (!collectPartialMuls(RealMuls, ImagMuls, Info)) + return nullptr; + + // Map to store common instruction to node pointers + std::map<Value *, NodePtr> CommonToNode; + std::vector<bool> Processed(Info.size(), false); + for (unsigned I = 0; I < Info.size(); ++I) { + if (Processed[I]) + continue; + + PartialMulCandidate &InfoA = Info[I]; + for (unsigned J = I + 1; J < Info.size(); ++J) { + if (Processed[J]) + continue; + + PartialMulCandidate &InfoB = Info[J]; + auto *InfoReal = &InfoA; + auto *InfoImag = &InfoB; + + auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); + if (!NodeFromCommon) { + std::swap(InfoReal, InfoImag); + NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); + } + if (!NodeFromCommon) + continue; + + CommonToNode[InfoReal->Common] = NodeFromCommon; + CommonToNode[InfoImag->Common] = NodeFromCommon; + Processed[I] = true; + Processed[J] = true; + } + } + + std::vector<bool> ProcessedReal(RealMuls.size(), false); + std::vector<bool> ProcessedImag(ImagMuls.size(), false); + NodePtr Result = Accumulator; + for (auto &PMI : Info) { + if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) + continue; + + auto It = CommonToNode.find(PMI.Common); + // TODO: Process independent complex multiplications. Cases like this: + // A.real() * B where both A and B are complex numbers. + if (It == CommonToNode.end()) { + LLVM_DEBUG({ + dbgs() << "Unprocessed independent partial multiplication:\n"; + for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) + dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier + << " multiplied by " << *Mul->Multiplicand << "\n"; + }); + return nullptr; + } + + auto &RealMul = RealMuls[PMI.RealIdx]; + auto &ImagMul = ImagMuls[PMI.ImagIdx]; + + auto NodeA = It->second; + auto NodeB = PMI.Node; + auto IsMultiplicandReal = PMI.Common == NodeA->Real; + // The following table illustrates the relationship between multiplications + // and rotations. If we consider the multiplication (X + iY) * (U + iV), we + // can see: + // + // Rotation | Real | Imag | + // ---------+--------+--------+ + // 0 | x * u | x * v | + // 90 | -y * v | y * u | + // 180 | -x * u | -x * v | + // 270 | y * v | -y * u | + // + // Check if the candidate can indeed be represented by partial + // multiplication + // TODO: Add support for multiplication by complex one + if ((IsMultiplicandReal && PMI.IsNodeInverted) || + (!IsMultiplicandReal && !PMI.IsNodeInverted)) + continue; + + // Determine the rotation based on the multiplications + ComplexDeinterleavingRotation Rotation; + if (IsMultiplicandReal) { + // Detect 0 and 180 degrees rotation + if (RealMul.IsPositive && ImagMul.IsPositive) + Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; + else if (!RealMul.IsPositive && !ImagMul.IsPositive) + Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; + else + continue; + + } else { + // Detect 90 and 270 degrees rotation + if (!RealMul.IsPositive && ImagMul.IsPositive) + Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; + else if (RealMul.IsPositive && !ImagMul.IsPositive) + Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; + else + continue; + } + + LLVM_DEBUG({ + dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; + dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; + dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; + dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; + dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; + dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; + }); + + NodePtr NodeMul = prepareCompositeNode( + ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); + NodeMul->Rotation = Rotation; + NodeMul->addOperand(NodeA); + NodeMul->addOperand(NodeB); + if (Result) + NodeMul->addOperand(Result); + submitCompositeNode(NodeMul); + Result = NodeMul; + ProcessedReal[PMI.RealIdx] = true; + ProcessedImag[PMI.ImagIdx] = true; + } + + // Ensure all products have been processed, if not return nullptr. + if (!all_of(ProcessedReal, [](bool V) { return V; }) || + !all_of(ProcessedImag, [](bool V) { return V; })) { + + // Dump debug information about which partial multiplications are not + // processed. + LLVM_DEBUG({ + dbgs() << "Unprocessed products (Real):\n"; + for (size_t i = 0; i < ProcessedReal.size(); ++i) { + if (!ProcessedReal[i]) + dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") + << *RealMuls[i].Multiplier << " multiplied by " + << *RealMuls[i].Multiplicand << "\n"; + } + dbgs() << "Unprocessed products (Imag):\n"; + for (size_t i = 0; i < ProcessedImag.size(); ++i) { + if (!ProcessedImag[i]) + dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") + << *ImagMuls[i].Multiplier << " multiplied by " + << *ImagMuls[i].Multiplicand << "\n"; + } + }); + return nullptr; + } + + return Result; +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyAdditions( + std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends, + std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) { + if (RealAddends.size() != ImagAddends.size()) + return nullptr; + + NodePtr Result; + // If we have accumulator use it as first addend + if (Accumulator) + Result = Accumulator; + // Otherwise find an element with both positive real and imaginary parts. + else + Result = extractPositiveAddend(RealAddends, ImagAddends); + + if (!Result) + return nullptr; + + while (!RealAddends.empty()) { + auto ItR = RealAddends.begin(); + auto [R, IsPositiveR] = *ItR; + + bool FoundImag = false; + for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { + auto [I, IsPositiveI] = *ItI; + ComplexDeinterleavingRotation Rotation; + if (IsPositiveR && IsPositiveI) + Rotation = ComplexDeinterleavingRotation::Rotation_0; + else if (!IsPositiveR && IsPositiveI) + Rotation = ComplexDeinterleavingRotation::Rotation_90; + else if (!IsPositiveR && !IsPositiveI) + Rotation = ComplexDeinterleavingRotation::Rotation_180; + else + Rotation = ComplexDeinterleavingRotation::Rotation_270; + + NodePtr AddNode; + if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || + Rotation == ComplexDeinterleavingRotation::Rotation_180) { + AddNode = identifyNode(R, I); + } else { + AddNode = identifyNode(I, R); + } + if (AddNode) { + LLVM_DEBUG({ + dbgs() << "Identified addition:\n"; + dbgs().indent(4) << "X: " << *R << "\n"; + dbgs().indent(4) << "Y: " << *I << "\n"; + dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; + }); + + NodePtr TmpNode; + if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { + TmpNode = prepareCompositeNode( + ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); + if (Flags) { + TmpNode->Opcode = Instruction::FAdd; + TmpNode->Flags = *Flags; + } else { + TmpNode->Opcode = Instruction::Add; + } + } else if (Rotation == + llvm::ComplexDeinterleavingRotation::Rotation_180) { + TmpNode = prepareCompositeNode( + ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); + if (Flags) { + TmpNode->Opcode = Instruction::FSub; + TmpNode->Flags = *Flags; + } else { + TmpNode->Opcode = Instruction::Sub; + } + } else { + TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, + nullptr, nullptr); + TmpNode->Rotation = Rotation; + } + + TmpNode->addOperand(Result); + TmpNode->addOperand(AddNode); + submitCompositeNode(TmpNode); + Result = TmpNode; + RealAddends.erase(ItR); + ImagAddends.erase(ItI); + FoundImag = true; + break; + } + } + if (!FoundImag) + return nullptr; + } + return Result; +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::extractPositiveAddend( + std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { + for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { + for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { + auto [R, IsPositiveR] = *ItR; + auto [I, IsPositiveI] = *ItI; + if (IsPositiveR && IsPositiveI) { + auto Result = identifyNode(R, I); + if (Result) { + RealAddends.erase(ItR); + ImagAddends.erase(ItI); + return Result; + } + } + } + } + return nullptr; +} + +bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { + // This potential root instruction might already have been recognized as + // reduction. Because RootToNode maps both Real and Imaginary parts to + // CompositeNode we should choose only one either Real or Imag instruction to + // use as an anchor for generating complex instruction. + auto It = RootToNode.find(RootI); + if (It != RootToNode.end()) { + auto RootNode = It->second; + assert(RootNode->Operation == + ComplexDeinterleavingOperation::ReductionOperation); + // Find out which part, Real or Imag, comes later, and only if we come to + // the latest part, add it to OrderedRoots. + auto *R = cast<Instruction>(RootNode->Real); + auto *I = cast<Instruction>(RootNode->Imag); + auto *ReplacementAnchor = R->comesBefore(I) ? I : R; + if (ReplacementAnchor != RootI) + return false; + OrderedRoots.push_back(RootI); + return true; + } + + auto RootNode = identifyRoot(RootI); + if (!RootNode) + return false; + + LLVM_DEBUG({ + Function *F = RootI->getFunction(); + BasicBlock *B = RootI->getParent(); + dbgs() << "Complex deinterleaving graph for " << F->getName() + << "::" << B->getName() << ".\n"; + dump(dbgs()); + dbgs() << "\n"; + }); + RootToNode[RootI] = RootNode; + OrderedRoots.push_back(RootI); + return true; +} + +bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { + bool FoundPotentialReduction = false; + + auto *Br = dyn_cast<BranchInst>(B->getTerminator()); + if (!Br || Br->getNumSuccessors() != 2) + return false; + + // Identify simple one-block loop + if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) + return false; + + SmallVector<PHINode *> PHIs; + for (auto &PHI : B->phis()) { + if (PHI.getNumIncomingValues() != 2) + continue; + + if (!PHI.getType()->isVectorTy()) + continue; + + auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); + if (!ReductionOp) + continue; + + // Check if final instruction is reduced outside of current block + Instruction *FinalReduction = nullptr; + auto NumUsers = 0u; + for (auto *U : ReductionOp->users()) { + ++NumUsers; + if (U == &PHI) + continue; + FinalReduction = dyn_cast<Instruction>(U); + } + + if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B || + isa<PHINode>(FinalReduction)) + continue; + + ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; + BackEdge = B; + auto BackEdgeIdx = PHI.getBasicBlockIndex(B); + auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0; + Incoming = PHI.getIncomingBlock(IncomingIdx); + FoundPotentialReduction = true; + + // If the initial value of PHINode is an Instruction, consider it a leaf + // value of a complex deinterleaving graph. + if (auto *InitPHI = + dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) + FinalInstructions.insert(InitPHI); + } + return FoundPotentialReduction; +} + +void ComplexDeinterleavingGraph::identifyReductionNodes() { + SmallVector<bool> Processed(ReductionInfo.size(), false); + SmallVector<Instruction *> OperationInstruction; + for (auto &P : ReductionInfo) + OperationInstruction.push_back(P.first); + + // Identify a complex computation by evaluating two reduction operations that + // potentially could be involved + for (size_t i = 0; i < OperationInstruction.size(); ++i) { + if (Processed[i]) + continue; + for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { + if (Processed[j]) + continue; + + auto *Real = OperationInstruction[i]; + auto *Imag = OperationInstruction[j]; + if (Real->getType() != Imag->getType()) + continue; + + RealPHI = ReductionInfo[Real].first; + ImagPHI = ReductionInfo[Imag].first; + PHIsFound = false; + auto Node = identifyNode(Real, Imag); + if (!Node) { + std::swap(Real, Imag); + std::swap(RealPHI, ImagPHI); + Node = identifyNode(Real, Imag); + } + + // If a node is identified and reduction PHINode is used in the chain of + // operations, mark its operation instructions as used to prevent + // re-identification and attach the node to the real part + if (Node && PHIsFound) { + LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " + << *Real << " / " << *Imag << "\n"); + Processed[i] = true; + Processed[j] = true; + auto RootNode = prepareCompositeNode( + ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); + RootNode->addOperand(Node); + RootToNode[Real] = RootNode; + RootToNode[Imag] = RootNode; + submitCompositeNode(RootNode); + break; + } + } + } + + RealPHI = nullptr; + ImagPHI = nullptr; +} + +bool ComplexDeinterleavingGraph::checkNodes() { + // Collect all instructions from roots to leaves + SmallPtrSet<Instruction *, 16> AllInstructions; + SmallVector<Instruction *, 8> Worklist; + for (auto &Pair : RootToNode) + Worklist.push_back(Pair.first); + + // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG + // chains + while (!Worklist.empty()) { + auto *I = Worklist.back(); + Worklist.pop_back(); + + if (!AllInstructions.insert(I).second) + continue; + + for (Value *Op : I->operands()) { + if (auto *OpI = dyn_cast<Instruction>(Op)) { + if (!FinalInstructions.count(I)) + Worklist.emplace_back(OpI); + } + } + } + + // Find instructions that have users outside of chain + SmallVector<Instruction *, 2> OuterInstructions; + for (auto *I : AllInstructions) { + // Skip root nodes + if (RootToNode.count(I)) + continue; + + for (User *U : I->users()) { + if (AllInstructions.count(cast<Instruction>(U))) + continue; + + // Found an instruction that is not used by XCMLA/XCADD chain + Worklist.emplace_back(I); + break; + } + } + + // If any instructions are found to be used outside, find and remove roots + // that somehow connect to those instructions. + SmallPtrSet<Instruction *, 16> Visited; + while (!Worklist.empty()) { + auto *I = Worklist.back(); + Worklist.pop_back(); + if (!Visited.insert(I).second) + continue; + + // Found an impacted root node. Removing it from the nodes to be + // deinterleaved + if (RootToNode.count(I)) { + LLVM_DEBUG(dbgs() << "Instruction " << *I + << " could be deinterleaved but its chain of complex " + "operations have an outside user\n"); + RootToNode.erase(I); + } + + if (!AllInstructions.count(I) || FinalInstructions.count(I)) + continue; + + for (User *U : I->users()) + Worklist.emplace_back(cast<Instruction>(U)); + + for (Value *Op : I->operands()) { + if (auto *OpI = dyn_cast<Instruction>(Op)) + Worklist.emplace_back(OpI); + } + } + return !RootToNode.empty(); +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { + if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { + if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2) + return nullptr; + + auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); + auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); + if (!Real || !Imag) + return nullptr; + + return identifyNode(Real, Imag); + } + + auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); + if (!SVI) + return nullptr; + + // Look for a shufflevector that takes separate vectors of the real and + // imaginary components and recombines them into a single vector. + if (!isInterleavingMask(SVI->getShuffleMask())) + return nullptr; + + Instruction *Real; + Instruction *Imag; + if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) + return nullptr; + + return identifyNode(Real, Imag); +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, + Instruction *Imag) { + Instruction *I = nullptr; + Value *FinalValue = nullptr; + if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && + match(Imag, m_ExtractValue<1>(m_Specific(I))) && + match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>( + m_Value(FinalValue)))) { + NodePtr PlaceholderNode = prepareCompositeNode( + llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); + PlaceholderNode->ReplacementNode = FinalValue; + FinalInstructions.insert(Real); + FinalInstructions.insert(Imag); + return submitCompositeNode(PlaceholderNode); + } + + auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); + auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); + if (!RealShuffle || !ImagShuffle) { + if (RealShuffle || ImagShuffle) + LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); + return nullptr; + } + + Value *RealOp1 = RealShuffle->getOperand(1); + if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { + LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); + return nullptr; + } + Value *ImagOp1 = ImagShuffle->getOperand(1); + if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { + LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); + return nullptr; + } + + Value *RealOp0 = RealShuffle->getOperand(0); + Value *ImagOp0 = ImagShuffle->getOperand(0); + + if (RealOp0 != ImagOp0) { + LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); + return nullptr; + } + + ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); + ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); + if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { + LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); + return nullptr; + } + + if (RealMask[0] != 0 || ImagMask[0] != 1) { + LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); + return nullptr; + } + + // Type checking, the shuffle type should be a vector type of the same + // scalar type, but half the size + auto CheckType = [&](ShuffleVectorInst *Shuffle) { + Value *Op = Shuffle->getOperand(0); + auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); + auto *OpTy = cast<FixedVectorType>(Op->getType()); + + if (OpTy->getScalarType() != ShuffleTy->getScalarType()) + return false; + if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) + return false; + + return true; + }; + + auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { + if (!CheckType(Shuffle)) + return false; + + ArrayRef<int> Mask = Shuffle->getShuffleMask(); + int Last = *Mask.rbegin(); + + Value *Op = Shuffle->getOperand(0); + auto *OpTy = cast<FixedVectorType>(Op->getType()); + int NumElements = OpTy->getNumElements(); + + // Ensure that the deinterleaving shuffle only pulls from the first + // shuffle operand. + return Last < NumElements; + }; + + if (RealShuffle->getType() != ImagShuffle->getType()) { + LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); + return nullptr; + } + if (!CheckDeinterleavingShuffle(RealShuffle)) { + LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); + return nullptr; + } + if (!CheckDeinterleavingShuffle(ImagShuffle)) { + LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); + return nullptr; + } + + NodePtr PlaceholderNode = + prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, + RealShuffle, ImagShuffle); + PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); + FinalInstructions.insert(RealShuffle); + FinalInstructions.insert(ImagShuffle); + return submitCompositeNode(PlaceholderNode); +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { + auto IsSplat = [](Value *V) -> bool { + // Fixed-width vector with constants + if (isa<ConstantDataVector>(V)) + return true; + + VectorType *VTy; + ArrayRef<int> Mask; + // Splats are represented differently depending on whether the repeated + // value is a constant or an Instruction + if (auto *Const = dyn_cast<ConstantExpr>(V)) { + if (Const->getOpcode() != Instruction::ShuffleVector) + return false; + VTy = cast<VectorType>(Const->getType()); + Mask = Const->getShuffleMask(); + } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) { + VTy = Shuf->getType(); + Mask = Shuf->getShuffleMask(); + } else { + return false; + } + + // When the data type is <1 x Type>, it's not possible to differentiate + // between the ComplexDeinterleaving::Deinterleave and + // ComplexDeinterleaving::Splat operations. + if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) + return false; + + return all_equal(Mask) && Mask[0] == 0; + }; + + if (!IsSplat(R) || !IsSplat(I)) + return nullptr; + + auto *Real = dyn_cast<Instruction>(R); + auto *Imag = dyn_cast<Instruction>(I); + if ((!Real && Imag) || (Real && !Imag)) + return nullptr; + + if (Real && Imag) { + // Non-constant splats should be in the same basic block + if (Real->getParent() != Imag->getParent()) + return nullptr; + + FinalInstructions.insert(Real); + FinalInstructions.insert(Imag); + } + NodePtr PlaceholderNode = + prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I); + return submitCompositeNode(PlaceholderNode); +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, + Instruction *Imag) { + if (Real != RealPHI || Imag != ImagPHI) + return nullptr; + + PHIsFound = true; + NodePtr PlaceholderNode = prepareCompositeNode( + ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); + return submitCompositeNode(PlaceholderNode); +} + +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, + Instruction *Imag) { + auto *SelectReal = dyn_cast<SelectInst>(Real); + auto *SelectImag = dyn_cast<SelectInst>(Imag); + if (!SelectReal || !SelectImag) + return nullptr; + + Instruction *MaskA, *MaskB; + Instruction *AR, *AI, *RA, *BI; + if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), + m_Instruction(RA))) || + !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), + m_Instruction(BI)))) + return nullptr; + + if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) + return nullptr; + + if (!MaskA->getType()->isVectorTy()) + return nullptr; + + auto NodeA = identifyNode(AR, AI); + if (!NodeA) + return nullptr; + + auto NodeB = identifyNode(RA, BI); + if (!NodeB) + return nullptr; + + NodePtr PlaceholderNode = prepareCompositeNode( + ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); + PlaceholderNode->addOperand(NodeA); + PlaceholderNode->addOperand(NodeB); + FinalInstructions.insert(MaskA); + FinalInstructions.insert(MaskB); + return submitCompositeNode(PlaceholderNode); +} + +static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, + std::optional<FastMathFlags> Flags, + Value *InputA, Value *InputB) { + Value *I; + switch (Opcode) { + case Instruction::FNeg: + I = B.CreateFNeg(InputA); + break; + case Instruction::FAdd: + I = B.CreateFAdd(InputA, InputB); + break; + case Instruction::Add: + I = B.CreateAdd(InputA, InputB); + break; + case Instruction::FSub: + I = B.CreateFSub(InputA, InputB); + break; + case Instruction::Sub: + I = B.CreateSub(InputA, InputB); + break; + case Instruction::FMul: + I = B.CreateFMul(InputA, InputB); + break; + case Instruction::Mul: + I = B.CreateMul(InputA, InputB); + break; + default: + llvm_unreachable("Incorrect symmetric opcode"); + } + if (Flags) + cast<Instruction>(I)->setFastMathFlags(*Flags); + return I; +} + +Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, + RawNodePtr Node) { + if (Node->ReplacementNode) + return Node->ReplacementNode; + + auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { + return Node->Operands.size() > Idx + ? replaceNode(Builder, Node->Operands[Idx]) + : nullptr; + }; + + Value *ReplacementNode; + switch (Node->Operation) { + case ComplexDeinterleavingOperation::CAdd: + case ComplexDeinterleavingOperation::CMulPartial: + case ComplexDeinterleavingOperation::Symmetric: { + Value *Input0 = ReplaceOperandIfExist(Node, 0); + Value *Input1 = ReplaceOperandIfExist(Node, 1); + Value *Accumulator = ReplaceOperandIfExist(Node, 2); + assert(!Input1 || (Input0->getType() == Input1->getType() && + "Node inputs need to be of the same type")); + assert(!Accumulator || + (Input0->getType() == Accumulator->getType() && + "Accumulator and input need to be of the same type")); + if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) + ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, + Input0, Input1); + else + ReplacementNode = TL->createComplexDeinterleavingIR( + Builder, Node->Operation, Node->Rotation, Input0, Input1, + Accumulator); + break; + } + case ComplexDeinterleavingOperation::Deinterleave: + llvm_unreachable("Deinterleave node should already have ReplacementNode"); + break; + case ComplexDeinterleavingOperation::Splat: { + auto *NewTy = VectorType::getDoubleElementsVectorType( + cast<VectorType>(Node->Real->getType())); + auto *R = dyn_cast<Instruction>(Node->Real); + auto *I = dyn_cast<Instruction>(Node->Imag); + if (R && I) { + // Splats that are not constant are interleaved where they are located + Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode(); + IRBuilder<> IRB(InsertPoint); + ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2, + NewTy, {Node->Real, Node->Imag}); + } else { + ReplacementNode = Builder.CreateIntrinsic( + Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag}); + } + break; + } + case ComplexDeinterleavingOperation::ReductionPHI: { + // If Operation is ReductionPHI, a new empty PHINode is created. + // It is filled later when the ReductionOperation is processed. + auto *VTy = cast<VectorType>(Node->Real->getType()); + auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); + auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt()); + OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI; + ReplacementNode = NewPHI; + break; + } + case ComplexDeinterleavingOperation::ReductionOperation: + ReplacementNode = replaceNode(Builder, Node->Operands[0]); + processReductionOperation(ReplacementNode, Node); + break; + case ComplexDeinterleavingOperation::ReductionSelect: { + auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0); + auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0); + auto *A = replaceNode(Builder, Node->Operands[0]); + auto *B = replaceNode(Builder, Node->Operands[1]); + auto *NewMaskTy = VectorType::getDoubleElementsVectorType( + cast<VectorType>(MaskReal->getType())); + auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, + NewMaskTy, {MaskReal, MaskImag}); + ReplacementNode = Builder.CreateSelect(NewMask, A, B); + break; + } + } + + assert(ReplacementNode && "Target failed to create Intrinsic call."); + NumComplexTransformations += 1; + Node->ReplacementNode = ReplacementNode; + return ReplacementNode; +} + +void ComplexDeinterleavingGraph::processReductionOperation( + Value *OperationReplacement, RawNodePtr Node) { + auto *Real = cast<Instruction>(Node->Real); + auto *Imag = cast<Instruction>(Node->Imag); + auto *OldPHIReal = ReductionInfo[Real].first; + auto *OldPHIImag = ReductionInfo[Imag].first; + auto *NewPHI = OldToNewPHI[OldPHIReal]; + + auto *VTy = cast<VectorType>(Real->getType()); + auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); + + // We have to interleave initial origin values coming from IncomingBlock + Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); + Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); + + IRBuilder<> Builder(Incoming->getTerminator()); + auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy, + {InitReal, InitImag}); + + NewPHI->addIncoming(NewInit, Incoming); + NewPHI->addIncoming(OperationReplacement, BackEdge); + + // Deinterleave complex vector outside of loop so that it can be finally + // reduced + auto *FinalReductionReal = ReductionInfo[Real].second; + auto *FinalReductionImag = ReductionInfo[Imag].second; + + Builder.SetInsertPoint( + &*FinalReductionReal->getParent()->getFirstInsertionPt()); + auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2, + OperationReplacement->getType(), + OperationReplacement); + + auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); + FinalReductionReal->replaceUsesOfWith(Real, NewReal); + + Builder.SetInsertPoint(FinalReductionImag); + auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); + FinalReductionImag->replaceUsesOfWith(Imag, NewImag); +} + +void ComplexDeinterleavingGraph::replaceNodes() { + SmallVector<Instruction *, 16> DeadInstrRoots; + for (auto *RootInstruction : OrderedRoots) { + // Check if this potential root went through check process and we can + // deinterleave it + if (!RootToNode.count(RootInstruction)) + continue; + + IRBuilder<> Builder(RootInstruction); + auto RootNode = RootToNode[RootInstruction]; + Value *R = replaceNode(Builder, RootNode.get()); + + if (RootNode->Operation == + ComplexDeinterleavingOperation::ReductionOperation) { + auto *RootReal = cast<Instruction>(RootNode->Real); + auto *RootImag = cast<Instruction>(RootNode->Imag); + ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); + ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); + DeadInstrRoots.push_back(cast<Instruction>(RootReal)); + DeadInstrRoots.push_back(cast<Instruction>(RootImag)); + } else { + assert(R && "Unable to find replacement for RootInstruction"); + DeadInstrRoots.push_back(RootInstruction); + RootInstruction->replaceAllUsesWith(R); + } + } + + for (auto *I : DeadInstrRoots) + RecursivelyDeleteTriviallyDeadInstructions(I, TLI); +} |
