diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 295 |
1 files changed, 192 insertions, 103 deletions
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 17594b98c5bc..f46ea6a20afa 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -72,6 +72,11 @@ static cl::opt<bool> AllowContractEnabled( cl::desc("Allow the use of FMAs if available and profitable. This may " "result in different results, due to less rounding error.")); +static cl::opt<bool> + VerifyShapeInfo("verify-matrix-shapes", cl::Hidden, + cl::desc("Enable/disable matrix shape verification."), + cl::init(false)); + enum class MatrixLayoutTy { ColumnMajor, RowMajor }; static cl::opt<MatrixLayoutTy> MatrixLayout( @@ -267,7 +272,7 @@ class LowerMatrixIntrinsics { unsigned D = isColumnMajor() ? NumColumns : NumRows; for (unsigned J = 0; J < D; ++J) - addVector(UndefValue::get(FixedVectorType::get( + addVector(PoisonValue::get(FixedVectorType::get( EltTy, isColumnMajor() ? NumRows : NumColumns))); } @@ -535,6 +540,15 @@ public: auto SIter = ShapeMap.find(V); if (SIter != ShapeMap.end()) { + if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows || + SIter->second.NumColumns != Shape.NumColumns)) { + errs() << "Conflicting shapes (" << SIter->second.NumRows << "x" + << SIter->second.NumColumns << " vs " << Shape.NumRows << "x" + << Shape.NumColumns << ") for " << *V << "\n"; + report_fatal_error( + "Matrix shape verification failed, compilation aborted!"); + } + LLVM_DEBUG(dbgs() << " not overriding existing shape: " << SIter->second.NumRows << " " << SIter->second.NumColumns << " for " << *V << "\n"); @@ -838,10 +852,13 @@ public: auto NewInst = distributeTransposes( TAMA, {R, C}, TAMB, {R, C}, Builder, [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { - auto *FAdd = - cast<Instruction>(LocalBuilder.CreateFAdd(T0, T1, "mfadd")); - setShapeInfo(FAdd, Shape0); - return FAdd; + bool IsFP = I.getType()->isFPOrFPVectorTy(); + auto *Add = IsFP ? 
LocalBuilder.CreateFAdd(T0, T1, "madd") + : LocalBuilder.CreateAdd(T0, T1, "madd"); + + auto *Result = cast<Instruction>(Add); + setShapeInfo(Result, Shape0); + return Result; }); updateShapeAndReplaceAllUsesWith(I, NewInst); eraseFromParentAndMove(&I, II, BB); @@ -978,13 +995,18 @@ public: MatrixInsts.push_back(&I); } - // Second, try to fuse candidates. + // Second, try to lower any dot products SmallPtrSet<Instruction *, 16> FusedInsts; for (CallInst *CI : MaybeFusableInsts) + lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI)); + + // Third, try to fuse candidates. + for (CallInst *CI : MaybeFusableInsts) LowerMatrixMultiplyFused(CI, FusedInsts); + Changed = !FusedInsts.empty(); - // Third, lower remaining instructions with shape information. + // Fourth, lower remaining instructions with shape information. for (Instruction *Inst : MatrixInsts) { if (FusedInsts.count(Inst)) continue; @@ -1311,6 +1333,165 @@ public: } } + /// Special case for MatMul lowering. Prevents scalar loads of row-major + /// vectors. Lowers to vector reduction add instead of sequential add if + /// reassociation is enabled. + void lowerDotProduct(CallInst *MatMul, + SmallPtrSet<Instruction *, 16> &FusedInsts, + FastMathFlags FMF) { + if (FusedInsts.contains(MatMul) || + MatrixLayout != MatrixLayoutTy::ColumnMajor) + return; + ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); + ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); + + if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product + return; + + Value *LHS = MatMul->getArgOperand(0); + Value *RHS = MatMul->getArgOperand(1); + + Type *ElementType = cast<VectorType>(LHS->getType())->getElementType(); + bool IsIntVec = ElementType->isIntegerTy(); + + // Floating point reductions require reassociation. 
+ if (!IsIntVec && !FMF.allowReassoc()) + return; + + auto CanBeFlattened = [this](Value *Op) { + if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) + return true; + return match( + Op, m_OneUse(m_CombineOr( + m_Load(m_Value()), + m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(), + m_Intrinsic<Intrinsic::matrix_column_major_load>( + m_Value(), m_SpecificInt(1)))))); + }; + // Returns the cost benefit of using \p Op with the dot product lowering. If + // the returned cost is < 0, the argument is cheaper to use in the + // dot-product lowering. + auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) { + if (!isa<Instruction>(Op)) + return InstructionCost(0); + + FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType()); + Type *EltTy = VecTy->getElementType(); + + if (!CanBeFlattened(Op)) { + InstructionCost EmbedCost(0); + // Roughly estimate the cost for embedding the columns into a vector. + for (unsigned I = 1; I < N; ++I) + EmbedCost -= + TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), + std::nullopt, TTI::TCK_RecipThroughput); + return EmbedCost; + } + + if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { + InstructionCost OriginalCost = + TTI.getArithmeticInstrCost(cast<Instruction>(Op)->getOpcode(), + EltTy) * + N; + InstructionCost NewCost = TTI.getArithmeticInstrCost( + cast<Instruction>(Op)->getOpcode(), VecTy); + return NewCost - OriginalCost; + } + + if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) { + // The transpose can be skipped for the dot product lowering, roughly + // estimate the savings as the cost of embedding the columns in a + // vector. + InstructionCost EmbedCost(0); + for (unsigned I = 1; I < N; ++I) + EmbedCost += + TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), + std::nullopt, TTI::TCK_RecipThroughput); + return EmbedCost; + } + + // Costs for loads. 
+ if (N == 1) + return InstructionCost(0); + + return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) - + N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); + }; + auto LHSCost = GetCostForArg(LHS, LShape.NumColumns); + + // We compare the costs of a vector.reduce.add to sequential add. + int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd; + int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul; + InstructionCost ReductionCost = + TTI.getArithmeticReductionCost( + AddOpCode, cast<VectorType>(LHS->getType()), + IsIntVec ? std::nullopt : std::optional(FMF)) + + TTI.getArithmeticInstrCost(MulOpCode, LHS->getType()); + InstructionCost SequentialAddCost = + TTI.getArithmeticInstrCost(AddOpCode, ElementType) * + (LShape.NumColumns - 1) + + TTI.getArithmeticInstrCost(MulOpCode, ElementType) * + (LShape.NumColumns); + if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0)) + return; + + FusedInsts.insert(MatMul); + IRBuilder<> Builder(MatMul); + auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened, + this](Value *Op) -> Value * { + // Matmul must be the only user of loads because we don't use LowerLoad + // for row vectors (LowerLoad results in scalar loads and shufflevectors + // instead of single vector load). 
+ if (!CanBeFlattened(Op)) + return Op; + + if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { + ShapeMap[Op] = ShapeMap[Op].t(); + return Op; + } + + FusedInsts.insert(cast<Instruction>(Op)); + // If vector uses the builtin load, lower to a LoadInst + Value *Arg; + if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>( + m_Value(Arg)))) { + auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); + Op->replaceAllUsesWith(NewLoad); + cast<Instruction>(Op)->eraseFromParent(); + return NewLoad; + } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( + m_Value(Arg)))) { + ToRemove.push_back(cast<Instruction>(Op)); + return Arg; + } + + return Op; + }; + LHS = FlattenArg(LHS); + + // Insert mul/fmul and llvm.vector.reduce.fadd + Value *Mul = + IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(LHS, RHS); + + Value *Result; + if (IsIntVec) + Result = Builder.CreateAddReduce(Mul); + else { + Result = Builder.CreateFAddReduce( + ConstantFP::get(cast<VectorType>(LHS->getType())->getElementType(), + 0.0), + Mul); + cast<Instruction>(Result)->setFastMathFlags(FMF); + } + + // pack scalar back into a matrix and then replace matmul inst + Result = Builder.CreateInsertElement(PoisonValue::get(MatMul->getType()), + Result, uint64_t(0)); + MatMul->replaceAllUsesWith(Result); + FusedInsts.insert(MatMul); + ToRemove.push_back(MatMul); + } + /// Compute \p Result += \p A * \p B for input matrices with left-associating /// addition. 
/// @@ -1469,15 +1650,14 @@ public: auto *ArrayTy = ArrayType::get(VT->getElementType(), VT->getNumElements()); AllocaInst *Alloca = Builder.CreateAlloca(ArrayTy, Load->getPointerAddressSpace()); - Value *BC = Builder.CreateBitCast(Alloca, VT->getPointerTo()); - Builder.CreateMemCpy(BC, Alloca->getAlign(), Load->getPointerOperand(), + Builder.CreateMemCpy(Alloca, Alloca->getAlign(), Load->getPointerOperand(), Load->getAlign(), LoadLoc.Size.getValue()); Builder.SetInsertPoint(Fusion, Fusion->begin()); PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); PHI->addIncoming(Load->getPointerOperand(), Check0); PHI->addIncoming(Load->getPointerOperand(), Check1); - PHI->addIncoming(BC, Copy); + PHI->addIncoming(Alloca, Copy); // Adjust DT. DTUpdates.push_back({DT->Insert, Check0, Check1}); @@ -2397,99 +2577,8 @@ void LowerMatrixIntrinsicsPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline( OS, MapClassName2PassName); - OS << "<"; + OS << '<'; if (Minimal) OS << "minimal"; - OS << ">"; -} - -namespace { - -class LowerMatrixIntrinsicsLegacyPass : public FunctionPass { -public: - static char ID; - - LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) { - initializeLowerMatrixIntrinsicsLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE); - bool C = LMT.Visit(); - return C; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetTransformInfoWrapperPass>(); - 
AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - } -}; -} // namespace - -static const char pass_name[] = "Lower the matrix intrinsics"; -char LowerMatrixIntrinsicsLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, - false, false) -INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name, - false, false) - -Pass *llvm::createLowerMatrixIntrinsicsPass() { - return new LowerMatrixIntrinsicsLegacyPass(); -} - -namespace { - -/// A lightweight version of the matrix lowering pass that only requires TTI. -/// Advanced features that require DT, AA or ORE like tiling are disabled. This -/// is used to lower matrix intrinsics if the main lowering pass is not run, for -/// example with -O0. 
-class LowerMatrixIntrinsicsMinimalLegacyPass : public FunctionPass { -public: - static char ID; - - LowerMatrixIntrinsicsMinimalLegacyPass() : FunctionPass(ID) { - initializeLowerMatrixIntrinsicsMinimalLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - LowerMatrixIntrinsics LMT(F, TTI, nullptr, nullptr, nullptr, nullptr); - bool C = LMT.Visit(); - return C; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.setPreservesCFG(); - } -}; -} // namespace - -static const char pass_name_minimal[] = "Lower the matrix intrinsics (minimal)"; -char LowerMatrixIntrinsicsMinimalLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsMinimalLegacyPass, - "lower-matrix-intrinsics-minimal", pass_name_minimal, - false, false) -INITIALIZE_PASS_END(LowerMatrixIntrinsicsMinimalLegacyPass, - "lower-matrix-intrinsics-minimal", pass_name_minimal, false, - false) - -Pass *llvm::createLowerMatrixIntrinsicsMinimalPass() { - return new LowerMatrixIntrinsicsMinimalLegacyPass(); + OS << '>'; } |