1//===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for working with "normalized" expressions.
10// See the comments at the top of ScalarEvolutionNormalization.h for details.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/ScalarEvolutionNormalization.h"
15#include "llvm/Analysis/LoopInfo.h"
16#include "llvm/Analysis/ScalarEvolution.h"
17#include "llvm/Analysis/ScalarEvolutionExpressions.h"
18using namespace llvm;
19
namespace {
/// TransformKind - The two directions in which NormalizeDenormalizeRewriter
/// can transform the add recurrences selected by its predicate.
enum TransformKind {
  /// Normalize - "Partially decrement" each selected add recurrence with
  /// respect to its loop (the inverse of Denormalize).
  Normalize,
  /// Denormalize - "Partially increment" each selected add recurrence with
  /// respect to its loop; essentially SCEVAddRecExpr::getPostIncExpr.
  Denormalize
};
} // namespace
29
30namespace {
31struct NormalizeDenormalizeRewriter
32    : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
33  const TransformKind Kind;
34
35  // NB! Pred is a function_ref.  Storing it here is okay only because
36  // we're careful about the lifetime of NormalizeDenormalizeRewriter.
37  const NormalizePredTy Pred;
38
39  NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred,
40                               ScalarEvolution &SE)
41      : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
42        Pred(Pred) {}
43  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
44};
45} // namespace
46
47const SCEV *
48NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
49  SmallVector<const SCEV *, 8> Operands;
50
51  transform(AR->operands(), std::back_inserter(Operands),
52            [&](const SCEV *Op) { return visit(Op); });
53
54  if (!Pred(AR))
55    return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
56
57  // Normalization and denormalization are fancy names for decrementing and
58  // incrementing a SCEV expression with respect to a set of loops.  Since
59  // Pred(AR) has returned true, we know we need to normalize or denormalize AR
60  // with respect to its loop.
61
62  if (Kind == Denormalize) {
63    // Denormalization / "partial increment" is essentially the same as \c
64    // SCEVAddRecExpr::getPostIncExpr.  Here we use an explicit loop to make the
65    // symmetry with Normalization clear.
66    for (int i = 0, e = Operands.size() - 1; i < e; i++)
67      Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]);
68  } else {
69    assert(Kind == Normalize && "Only two possibilities!");
70
71    // Normalization / "partial decrement" is a bit more subtle.  Since
72    // incrementing a SCEV expression (in general) changes the step of the SCEV
73    // expression as well, we cannot use the step of the current expression.
74    // Instead, we have to use the step of the very expression we're trying to
75    // compute!
76    //
77    // We solve the issue by recursively building up the result, starting from
78    // the "least significant" operand in the add recurrence:
79    //
80    // Base case:
81    //   Single operand add recurrence.  It's its own normalization.
82    //
83    // N-operand case:
84    //   {S_{N-1},+,S_{N-2},+,...,+,S_0} = S
85    //
86    //   Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its
87    //   normalization by induction.  We subtract the normalized step
88    //   recurrence from S_{N-1} to get the normalization of S.
89
90    for (int i = Operands.size() - 2; i >= 0; i--)
91      Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
92  }
93
94  return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
95}
96
97const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
98                                         const PostIncLoopSet &Loops,
99                                         ScalarEvolution &SE) {
100  auto Pred = [&](const SCEVAddRecExpr *AR) {
101    return Loops.count(AR->getLoop());
102  };
103  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
104}
105
106const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
107                                           ScalarEvolution &SE) {
108  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
109}
110
111const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
112                                           const PostIncLoopSet &Loops,
113                                           ScalarEvolution &SE) {
114  auto Pred = [&](const SCEVAddRecExpr *AR) {
115    return Loops.count(AR->getLoop());
116  };
117  return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S);
118}
119