aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
blob: 36c91b7fa97e462b6000cb8598faef9b44aa737c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
//=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics
// with vector operands) with matching calls to functions from a vector
// library (e.g., libmvec, SVML) according to TargetLibraryInfo.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ReplaceWithVeclib.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

using namespace llvm;

#define DEBUG_TYPE "replace-with-veclib"

// Incremented in replaceWithTLIFunction each time an intrinsic call has been
// rewritten into a call to a vector library function.
STATISTIC(NumCallsReplaced,
          "Number of calls to intrinsics that have been replaced.");

// Incremented in replaceWithTLIFunction whenever a vector library function
// was not yet present in the module and a new function had to be created.
STATISTIC(NumTLIFuncDeclAdded,
          "Number of vector library function declarations added.");

// Incremented in replaceWithTLIFunction right after a freshly created
// function has been appended to `llvm.compiler.used`.
STATISTIC(NumFuncUsedAdded,
          "Number of functions added to `llvm.compiler.used`");

/// Replace the vector intrinsic call \p CI with a call to the vector library
/// function named \p TLIName.
///
/// If the module does not yet contain a function called \p TLIName, one is
/// created with the intrinsic's function type and attributes and appended to
/// `llvm.compiler.used`. The new call is built in front of \p CI with the same
/// arguments and operand bundles and takes over all uses of \p CI; fast-math
/// flags are carried over when the replacement is an FP math operation.
/// \p CI itself is not erased here.
///
/// \returns true if the call was replaced; false if the module already holds
/// a function called \p TLIName whose function type differs from the
/// intrinsic's, in which case the IR is left untouched.
static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
  Module *M = CI.getModule();
  Function *OldFunc = CI.getCalledFunction();
  FunctionType *ExpectedTy = OldFunc->getFunctionType();

  // Check if the vector library function is already present in this module.
  Function *TLIFunc = M->getFunction(TLIName);
  if (TLIFunc && TLIFunc->getFunctionType() != ExpectedTy) {
    // A function with this name but another signature already exists.
    // Building a call to it with the intrinsic's operands would yield
    // ill-typed IR (previously only caught by an assert after the call had
    // been created), so bail out before modifying anything.
    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Not replacing call to `"
                      << OldFunc->getName() << "`: existing function `"
                      << TLIName << "` has a mismatching type.\n");
    return false;
  }

  if (!TLIFunc) {
    // Not present yet: insert it, mirroring the intrinsic's type and
    // attributes.
    TLIFunc =
        Function::Create(ExpectedTy, Function::ExternalLinkage, TLIName, *M);
    TLIFunc->copyAttributesFrom(OldFunc);

    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
                      << TLIName << "` of type `" << *(TLIFunc->getType())
                      << "` to module.\n");

    ++NumTLIFuncDeclAdded;

    // Add the freshly created function to llvm.compiler.used,
    // similar to as it is done in InjectTLIMappings
    appendToCompilerUsed(*M, {TLIFunc});

    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
                      << "` to `@llvm.compiler.used`.\n");
    ++NumFuncUsedAdded;
  }

  assert(ExpectedTy == TLIFunc->getFunctionType() &&
         "Expecting function types to be identical");

  // Replace the call to the vector intrinsic with a call
  // to the corresponding function from the vector library.
  IRBuilder<> Builder(&CI);
  SmallVector<Value *> Args(CI.args());
  // Preserve the operand bundles.
  SmallVector<OperandBundleDef, 1> OpBundles;
  CI.getOperandBundlesAsDefs(OpBundles);
  CallInst *Replacement = Builder.CreateCall(TLIFunc, Args, OpBundles);
  CI.replaceAllUsesWith(Replacement);
  if (isa<FPMathOperator>(Replacement)) {
    // Preserve fast math flags for FP math.
    Replacement->copyFastMathFlags(&CI);
  }

  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
                    << OldFunc->getName() << "` with call to `" << TLIName
                    << "`.\n");
  ++NumCallsReplaced;
  return true;
}

static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
                                    CallInst &CI) {
  if (!CI.getCalledFunction()) {
    return false;
  }

  auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
  if (IntrinsicID == Intrinsic::not_intrinsic) {
    // Replacement is only performed for intrinsic functions
    return false;
  }

  // Convert vector arguments to scalar type and check that
  // all vector operands have identical vector width.
  ElementCount VF = ElementCount::getFixed(0);
  SmallVector<Type *> ScalarTypes;
  for (auto Arg : enumerate(CI.args())) {
    auto *ArgType = Arg.value()->getType();
    // Vector calls to intrinsics can still have
    // scalar operands for specific arguments.
    if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
      ScalarTypes.push_back(ArgType);
    } else {
      // The argument in this place should be a vector if
      // this is a call to a vector intrinsic.
      auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
      if (!VectorArgTy) {
        // The argument is not a vector, do not perform
        // the replacement.
        return false;
      }
      ElementCount NumElements = VectorArgTy->getElementCount();
      if (NumElements.isScalable()) {
        // The current implementation does not support
        // scalable vectors.
        return false;
      }
      if (VF.isNonZero() && VF != NumElements) {
        // The different arguments differ in vector size.
        return false;
      } else {
        VF = NumElements;
      }
      ScalarTypes.push_back(VectorArgTy->getElementType());
    }
  }

  // Try to reconstruct the name for the scalar version of this
  // intrinsic using the intrinsic ID and the argument types
  // converted to scalar above.
  std::string ScalarName;
  if (Intrinsic::isOverloaded(IntrinsicID)) {
    ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
  } else {
    ScalarName = Intrinsic::getName(IntrinsicID).str();
  }

  if (!TLI.isFunctionVectorizable(ScalarName)) {
    // The TargetLibraryInfo does not contain a vectorized version of
    // the scalar function.
    return false;
  }

  // Try to find the mapping for the scalar version of this intrinsic
  // and the exact vector width of the call operands in the
  // TargetLibraryInfo.
  StringRef TLIName = TLI.getVectorizedFunction(ScalarName, VF);

  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
                    << ScalarName << "` and vector width " << VF << ".\n");

  if (!TLIName.empty()) {
    // Found the correct mapping in the TargetLibraryInfo,
    // replace the call to the intrinsic with a call to
    // the vector library function.
    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
                      << "`.\n");
    return replaceWithTLIFunction(CI, TLIName);
  }

  return false;
}

static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
  bool Changed = false;
  SmallVector<CallInst *> ReplacedCalls;
  for (auto &I : instructions(F)) {
    if (auto *CI = dyn_cast<CallInst>(&I)) {
      if (replaceWithCallToVeclib(TLI, *CI)) {
        ReplacedCalls.push_back(CI);
        Changed = true;
      }
    }
  }
  // Erase the calls to the intrinsics that have been replaced
  // with calls to the vector library.
  for (auto *CI : ReplacedCalls) {
    CI->eraseFromParent();
  }
  return Changed;
}

////////////////////////////////////////////////////////////////////////////////
// New pass manager implementation.
////////////////////////////////////////////////////////////////////////////////
/// New pass manager entry point: runs the replacement on \p F with the
/// TargetLibraryInfo obtained from \p AM and reports which analyses remain
/// valid afterwards.
PreservedAnalyses ReplaceWithVeclib::run(Function &F,
                                         FunctionAnalysisManager &AM) {
  const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  if (!runImpl(TLI, F)) {
    // No call was replaced, so every analysis is still valid.
    return PreservedAnalyses::all();
  }

  // Calls were replaced: report only the analyses listed here as preserved.
  PreservedAnalyses Preserved;
  Preserved.preserveSet<CFGAnalyses>();
  Preserved.preserve<TargetLibraryAnalysis>();
  Preserved.preserve<ScalarEvolutionAnalysis>();
  Preserved.preserve<LoopAccessAnalysis>();
  Preserved.preserve<DemandedBitsAnalysis>();
  Preserved.preserve<OptimizationRemarkEmitterAnalysis>();
  return Preserved;
}

////////////////////////////////////////////////////////////////////////////////
// Legacy PM Implementation.
////////////////////////////////////////////////////////////////////////////////
/// Legacy pass manager entry point: fetches the TargetLibraryInfo for \p F
/// from the wrapper pass and hands it to the shared implementation.
///
/// \returns true if any call in \p F was replaced.
bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
  return runImpl(getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F), F);
}

/// Declares to the legacy pass manager which analyses this pass requires and
/// which ones it preserves.
void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
  // Required input: runOnFunction reads the TargetLibraryInfo from this pass.
  AU.addRequired<TargetLibraryInfoWrapperPass>();

  // Declared as preserved: the CFG plus the analyses listed below.
  AU.setPreservesCFG();
  AU.addPreserved<TargetLibraryInfoWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  AU.addPreserved<AAResultsWrapperPass>();
  AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
  AU.addPreserved<GlobalsAAWrapperPass>();
}

////////////////////////////////////////////////////////////////////////////////
// Legacy Pass manager initialization
////////////////////////////////////////////////////////////////////////////////
// Definition of the legacy pass' static ID member, initialised to zero.
char ReplaceWithVeclibLegacy::ID = 0;

// Register the legacy pass under the DEBUG_TYPE name ("replace-with-veclib")
// with the description string below, and declare
// TargetLibraryInfoWrapperPass as a pass it depends on.
// NOTE(review): the trailing `false, false` arguments are presumably the
// CFG-only / is-analysis flags of the INITIALIZE_PASS_* macros — confirm
// against their definition.
INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                      "Replace intrinsics with calls to vector library", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                    "Replace intrinsics with calls to vector library", false,
                    false)

/// Factory for the legacy pass manager: allocates a new
/// ReplaceWithVeclibLegacy instance and returns it as a FunctionPass pointer.
FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
  FunctionPass *LegacyPass = new ReplaceWithVeclibLegacy();
  return LegacyPass;
}