diff options
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolutionNormalization.cpp')
| -rw-r--r-- | llvm/lib/Analysis/ScalarEvolutionNormalization.cpp | 117 | 
1 files changed, 117 insertions, 0 deletions
| diff --git a/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp b/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp new file mode 100644 index 000000000000..209ae66ca53e --- /dev/null +++ b/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp @@ -0,0 +1,117 @@ +//===- ScalarEvolutionNormalization.cpp - See below -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for working with "normalized" expressions. +// See the comments at the top of ScalarEvolutionNormalization.h for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ScalarEvolutionNormalization.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +using namespace llvm; + +/// TransformKind - Different types of transformations that +/// TransformForPostIncUse can do. +enum TransformKind { +  /// Normalize - Normalize according to the given loops. +  Normalize, +  /// Denormalize - Perform the inverse transform on the expression with the +  /// given loop set. +  Denormalize +}; + +namespace { +struct NormalizeDenormalizeRewriter +    : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> { +  const TransformKind Kind; + +  // NB! Pred is a function_ref.  Storing it here is okay only because +  // we're careful about the lifetime of NormalizeDenormalizeRewriter. +  const NormalizePredTy Pred; + +  NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred, +                               ScalarEvolution &SE) +      : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind), +        Pred(Pred) {} +  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr); +}; +} // namespace + +const SCEV * +NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) { +  SmallVector<const SCEV *, 8> Operands; + +  transform(AR->operands(), std::back_inserter(Operands), +            [&](const SCEV *Op) { return visit(Op); }); + +  if (!Pred(AR)) +    return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); + +  // Normalization and denormalization are fancy names for decrementing and +  // incrementing a SCEV expression with respect to a set of loops.  Since +  // Pred(AR) has returned true, we know we need to normalize or denormalize AR +  // with respect to its loop. + +  if (Kind == Denormalize) { +    // Denormalization / "partial increment" is essentially the same as \c +    // SCEVAddRecExpr::getPostIncExpr.  Here we use an explicit loop to make the +    // symmetry with Normalization clear. +    for (int i = 0, e = Operands.size() - 1; i < e; i++) +      Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]); +  } else { +    assert(Kind == Normalize && "Only two possibilities!"); + +    // Normalization / "partial decrement" is a bit more subtle.  Since +    // incrementing a SCEV expression (in general) changes the step of the SCEV +    // expression as well, we cannot use the step of the current expression. +    // Instead, we have to use the step of the very expression we're trying to +    // compute! +    // +    // We solve the issue by recursively building up the result, starting from +    // the "least significant" operand in the add recurrence: +    // +    // Base case: +    //   Single operand add recurrence.  It's its own normalization. +    // +    // N-operand case: +    //   {S_{N-1},+,S_{N-2},+,...,+,S_0} = S +    // +    //   Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its +    //   normalization by induction.  We subtract the normalized step +    //   recurrence from S_{N-1} to get the normalization of S. + +    for (int i = Operands.size() - 2; i >= 0; i--) +      Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]); +  } + +  return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); +} + +const SCEV *llvm::normalizeForPostIncUse(const SCEV *S, +                                         const PostIncLoopSet &Loops, +                                         ScalarEvolution &SE) { +  auto Pred = [&](const SCEVAddRecExpr *AR) { +    return Loops.count(AR->getLoop()); +  }; +  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); +} + +const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred, +                                           ScalarEvolution &SE) { +  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); +} + +const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S, +                                           const PostIncLoopSet &Loops, +                                           ScalarEvolution &SE) { +  auto Pred = [&](const SCEVAddRecExpr *AR) { +    return Loops.count(AR->getLoop()); +  }; +  return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S); +} | 
