diff options
Diffstat (limited to 'contrib/llvm-project/clang/lib/CodeGen/Targets/RISCV.cpp')
| -rw-r--r-- | contrib/llvm-project/clang/lib/CodeGen/Targets/RISCV.cpp | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/contrib/llvm-project/clang/lib/CodeGen/Targets/RISCV.cpp b/contrib/llvm-project/clang/lib/CodeGen/Targets/RISCV.cpp index 0851d1993d0c..02c86ad2e58c 100644 --- a/contrib/llvm-project/clang/lib/CodeGen/Targets/RISCV.cpp +++ b/contrib/llvm-project/clang/lib/CodeGen/Targets/RISCV.cpp @@ -321,20 +321,28 @@ ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty) const { assert(Ty->isVectorType() && "expected vector type!"); const auto *VT = Ty->castAs<VectorType>(); - assert(VT->getVectorKind() == VectorKind::RVVFixedLengthData && - "Unexpected vector kind"); - assert(VT->getElementType()->isBuiltinType() && "expected builtin type!"); auto VScale = getContext().getTargetInfo().getVScaleRange(getContext().getLangOpts()); + + unsigned NumElts = VT->getNumElements(); + llvm::Type *EltType; + if (VT->getVectorKind() == VectorKind::RVVFixedLengthMask) { + NumElts *= 8; + EltType = llvm::Type::getInt1Ty(getVMContext()); + } else { + assert(VT->getVectorKind() == VectorKind::RVVFixedLengthData && + "Unexpected vector kind"); + EltType = CGT.ConvertType(VT->getElementType()); + } + // The MinNumElts is simplified from equation: // NumElts / VScale = // (EltSize * NumElts / (VScale * RVVBitsPerBlock)) // * (RVVBitsPerBlock / EltSize) llvm::ScalableVectorType *ResType = - llvm::ScalableVectorType::get(CGT.ConvertType(VT->getElementType()), - VT->getNumElements() / VScale->first); + llvm::ScalableVectorType::get(EltType, NumElts / VScale->first); return ABIArgInfo::getDirect(ResType); } @@ -437,7 +445,8 @@ ABIArgInfo RISCVABIInfo::classifyArgumentType(QualType Ty, bool IsFixed, } if (const VectorType *VT = Ty->getAs<VectorType>()) - if (VT->getVectorKind() == VectorKind::RVVFixedLengthData) + if (VT->getVectorKind() == VectorKind::RVVFixedLengthData || + VT->getVectorKind() == VectorKind::RVVFixedLengthMask) return coerceVLSVector(Ty); // Aggregates which are <= 2*XLen will be passed in registers if possible, |
