SyncDependenceAnalysis.cpp revision 360784
1//===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation
2//--===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements an algorithm that returns for a divergent branch
11// the set of basic blocks whose phi nodes become divergent due to divergent
12// control. These are the blocks that are reachable by two disjoint paths from
13// the branch or loop exits that have a reaching path that is disjoint from a
14// path to the loop latch.
15//
16// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model
17// control-induced divergence in phi nodes.
18//
19// -- Summary --
20// The SyncDependenceAnalysis lazily computes sync dependences [3].
21// The analysis evaluates the disjoint path criterion [2] by a reduction
22// to SSA construction. The SSA construction algorithm is implemented as
23// a simple data-flow analysis [1].
24//
25// [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy
26// [2] "Efficiently Computing Static Single Assignment Form
27//     and the Control Dependence Graph", TOPLAS '91,
28//           Cytron, Ferrante, Rosen, Wegman and Zadeck
29// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack
30// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira
31//
32// -- Sync dependence --
33// Sync dependence [4] characterizes the control flow aspect of the
34// propagation of branch divergence. For example,
35//
36//   %cond = icmp slt i32 %tid, 10
37//   br i1 %cond, label %then, label %else
38// then:
39//   br label %merge
40// else:
41//   br label %merge
42// merge:
43//   %a = phi i32 [ 0, %then ], [ 1, %else ]
44//
45// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
46// because %tid is not on its use-def chains, %a is sync dependent on %tid
47// because the branch "br i1 %cond" depends on %tid and affects which value %a
48// is assigned to.
49//
50// -- Reduction to SSA construction --
51// There are two disjoint paths from A to X, if a certain variant of SSA
52// construction places a phi node in X under the following set-up scheme [2].
53//
54// This variant of SSA construction ignores incoming undef values.
55// That is paths from the entry without a definition do not result in
56// phi nodes.
57//
58//       entry
59//     /      \
60//    A        \
61//  /   \       Y
62// B     C     /
63//  \   /  \  /
64//    D     E
65//     \   /
66//       F
67// Assume that A contains a divergent branch. We are interested
68// in the set of all blocks where each block is reachable from A
69// via two disjoint paths. This would be the set {D, F} in this
70// case.
71// To generally reduce this query to SSA construction we introduce
72// a virtual variable x and assign to x different values in each
73// successor block of A.
74//           entry
75//         /      \
76//        A        \
77//      /   \       Y
78// x = 0   x = 1   /
79//      \  /   \  /
80//        D     E
81//         \   /
82//           F
83// Our flavor of SSA construction for x will construct the following
84//            entry
85//          /      \
86//         A        \
87//       /   \       Y
88// x0 = 0   x1 = 1  /
89//       \   /   \ /
90//      x2=phi    E
91//         \     /
92//          x3=phi
93// The blocks D and F contain phi nodes and are thus each reachable
94// by two disjoint paths from A.
95//
96// -- Remarks --
97// In case of loop exits we need to check the disjoint path criterion for loops
98// [2]. To this end, we check whether the definition of x differs between the
99// loop exit and the loop header (_after_ SSA construction).
100//
101//===----------------------------------------------------------------------===//
102#include "llvm/ADT/PostOrderIterator.h"
103#include "llvm/ADT/SmallPtrSet.h"
104#include "llvm/Analysis/PostDominators.h"
105#include "llvm/Analysis/SyncDependenceAnalysis.h"
106#include "llvm/IR/BasicBlock.h"
107#include "llvm/IR/CFG.h"
108#include "llvm/IR/Dominators.h"
109#include "llvm/IR/Function.h"
110
111#include <stack>
112#include <unordered_set>
113
114#define DEBUG_TYPE "sync-dependence"
115
116namespace llvm {
117
118ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet;
119
120SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT,
121                                               const PostDominatorTree &PDT,
122                                               const LoopInfo &LI)
123    : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {}
124
125SyncDependenceAnalysis::~SyncDependenceAnalysis() {}
126
127using FunctionRPOT = ReversePostOrderTraversal<const Function *>;
128
129// divergence propagator for reducible CFGs
130struct DivergencePropagator {
131  const FunctionRPOT &FuncRPOT;
132  const DominatorTree &DT;
133  const PostDominatorTree &PDT;
134  const LoopInfo &LI;
135
136  // identified join points
137  std::unique_ptr<ConstBlockSet> JoinBlocks;
138
139  // reached loop exits (by a path disjoint to a path to the loop header)
140  SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits;
141
142  // if DefMap[B] == C then C is the dominating definition at block B
143  // if DefMap[B] ~ undef then we haven't seen B yet
144  // if DefMap[B] == B then B is a join point of disjoint paths from X or B is
145  // an immediate successor of X (initial value).
146  using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>;
147  DefiningBlockMap DefMap;
148
149  // all blocks with pending visits
150  std::unordered_set<const BasicBlock *> PendingUpdates;
151
152  DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT,
153                       const PostDominatorTree &PDT, const LoopInfo &LI)
154      : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI),
155        JoinBlocks(new ConstBlockSet) {}
156
157  // set the definition at @block and mark @block as pending for a visit
158  void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) {
159    bool WasAdded = DefMap.emplace(&Block, &DefBlock).second;
160    if (WasAdded)
161      PendingUpdates.insert(&Block);
162  }
163
164  void printDefs(raw_ostream &Out) {
165    Out << "Propagator::DefMap {\n";
166    for (const auto *Block : FuncRPOT) {
167      auto It = DefMap.find(Block);
168      Out << Block->getName() << " : ";
169      if (It == DefMap.end()) {
170        Out << "\n";
171      } else {
172        const auto *DefBlock = It->second;
173        Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n";
174      }
175    }
176    Out << "}\n";
177  }
178
179  // process @succBlock with reaching definition @defBlock
180  // the original divergent branch was in @parentLoop (if any)
181  void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop,
182                      const BasicBlock &DefBlock) {
183
184    // @succBlock is a loop exit
185    if (ParentLoop && !ParentLoop->contains(&SuccBlock)) {
186      DefMap.emplace(&SuccBlock, &DefBlock);
187      ReachedLoopExits.insert(&SuccBlock);
188      return;
189    }
190
191    // first reaching def?
192    auto ItLastDef = DefMap.find(&SuccBlock);
193    if (ItLastDef == DefMap.end()) {
194      addPending(SuccBlock, DefBlock);
195      return;
196    }
197
198    // a join of at least two definitions
199    if (ItLastDef->second != &DefBlock) {
200      // do we know this join already?
201      if (!JoinBlocks->insert(&SuccBlock).second)
202        return;
203
204      // update the definition
205      addPending(SuccBlock, SuccBlock);
206    }
207  }
208
209  // find all blocks reachable by two disjoint paths from @rootTerm.
210  // This method works for both divergent terminators and loops with
211  // divergent exits.
212  // @rootBlock is either the block containing the branch or the header of the
213  // divergent loop.
214  // @nodeSuccessors is the set of successors of the node (Loop or Terminator)
215  // headed by @rootBlock.
216  // @parentLoop is the parent loop of the Loop or the loop that contains the
217  // Terminator.
218  template <typename SuccessorIterable>
219  std::unique_ptr<ConstBlockSet>
220  computeJoinPoints(const BasicBlock &RootBlock,
221                    SuccessorIterable NodeSuccessors, const Loop *ParentLoop) {
222    assert(JoinBlocks);
223
224    LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: " << (ParentLoop ? ParentLoop->getName() : "<null>") << "\n" );
225
226    // bootstrap with branch targets
227    for (const auto *SuccBlock : NodeSuccessors) {
228      DefMap.emplace(SuccBlock, SuccBlock);
229
230      if (ParentLoop && !ParentLoop->contains(SuccBlock)) {
231        // immediate loop exit from node.
232        ReachedLoopExits.insert(SuccBlock);
233      } else {
234        // regular successor
235        PendingUpdates.insert(SuccBlock);
236      }
237    }
238
239    LLVM_DEBUG(
240      dbgs() << "SDA: rpo order:\n";
241      for (const auto * RpoBlock : FuncRPOT) {
242        dbgs() << "- " << RpoBlock->getName() << "\n";
243      }
244    );
245
246    auto ItBeginRPO = FuncRPOT.begin();
247
248    // skip until term (TODO RPOT won't let us start at @term directly)
249    for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) {}
250
251    auto ItEndRPO = FuncRPOT.end();
252    assert(ItBeginRPO != ItEndRPO);
253
254    // propagate definitions at the immediate successors of the node in RPO
255    auto ItBlockRPO = ItBeginRPO;
256    while ((++ItBlockRPO != ItEndRPO) &&
257           !PendingUpdates.empty()) {
258      const auto *Block = *ItBlockRPO;
259      LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n");
260
261      // skip Block if not pending update
262      auto ItPending = PendingUpdates.find(Block);
263      if (ItPending == PendingUpdates.end())
264        continue;
265      PendingUpdates.erase(ItPending);
266
267      // propagate definition at Block to its successors
268      auto ItDef = DefMap.find(Block);
269      const auto *DefBlock = ItDef->second;
270      assert(DefBlock);
271
272      auto *BlockLoop = LI.getLoopFor(Block);
273      if (ParentLoop &&
274          (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) {
275        // if the successor is the header of a nested loop pretend its a
276        // single node with the loop's exits as successors
277        SmallVector<BasicBlock *, 4> BlockLoopExits;
278        BlockLoop->getExitBlocks(BlockLoopExits);
279        for (const auto *BlockLoopExit : BlockLoopExits) {
280          visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock);
281        }
282
283      } else {
284        // the successors are either on the same loop level or loop exits
285        for (const auto *SuccBlock : successors(Block)) {
286          visitSuccessor(*SuccBlock, ParentLoop, *DefBlock);
287        }
288      }
289    }
290
291    LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs()));
292
293    // We need to know the definition at the parent loop header to decide
294    // whether the definition at the header is different from the definition at
295    // the loop exits, which would indicate a divergent loop exits.
296    //
297    // A // loop header
298    // |
299    // B // nested loop header
300    // |
301    // C -> X (exit from B loop) -..-> (A latch)
302    // |
303    // D -> back to B (B latch)
304    // |
305    // proper exit from both loops
306    //
307    // analyze reached loop exits
308    if (!ReachedLoopExits.empty()) {
309      const BasicBlock *ParentLoopHeader =
310          ParentLoop ? ParentLoop->getHeader() : nullptr;
311
312      assert(ParentLoop);
313      auto ItHeaderDef = DefMap.find(ParentLoopHeader);
314      const auto *HeaderDefBlock = (ItHeaderDef == DefMap.end()) ? nullptr : ItHeaderDef->second;
315
316      LLVM_DEBUG(printDefs(dbgs()));
317      assert(HeaderDefBlock && "no definition at header of carrying loop");
318
319      for (const auto *ExitBlock : ReachedLoopExits) {
320        auto ItExitDef = DefMap.find(ExitBlock);
321        assert((ItExitDef != DefMap.end()) &&
322               "no reaching def at reachable loop exit");
323        if (ItExitDef->second != HeaderDefBlock) {
324          JoinBlocks->insert(ExitBlock);
325        }
326      }
327    }
328
329    return std::move(JoinBlocks);
330  }
331};
332
333const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) {
334  using LoopExitVec = SmallVector<BasicBlock *, 4>;
335  LoopExitVec LoopExits;
336  Loop.getExitBlocks(LoopExits);
337  if (LoopExits.size() < 1) {
338    return EmptyBlockSet;
339  }
340
341  // already available in cache?
342  auto ItCached = CachedLoopExitJoins.find(&Loop);
343  if (ItCached != CachedLoopExitJoins.end()) {
344    return *ItCached->second;
345  }
346
347  // compute all join points
348  DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
349  auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>(
350      *Loop.getHeader(), LoopExits, Loop.getParentLoop());
351
352  auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks));
353  assert(ItInserted.second);
354  return *ItInserted.first->second;
355}
356
357const ConstBlockSet &
358SyncDependenceAnalysis::join_blocks(const Instruction &Term) {
359  // trivial case
360  if (Term.getNumSuccessors() < 1) {
361    return EmptyBlockSet;
362  }
363
364  // already available in cache?
365  auto ItCached = CachedBranchJoins.find(&Term);
366  if (ItCached != CachedBranchJoins.end())
367    return *ItCached->second;
368
369  // compute all join points
370  DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
371  const auto &TermBlock = *Term.getParent();
372  auto JoinBlocks = Propagator.computeJoinPoints<succ_const_range>(
373      TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock));
374
375  auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks));
376  assert(ItInserted.second);
377  return *ItInserted.first->second;
378}
379
380} // namespace llvm
381