CFGMST.h revision 360784
1//===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a Union-find algorithm to compute Minimum Spanning Tree
10// for a given CFG.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
15#define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
16
17#include "llvm/ADT/DenseMap.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/Analysis/BlockFrequencyInfo.h"
20#include "llvm/Analysis/BranchProbabilityInfo.h"
21#include "llvm/Analysis/CFG.h"
22#include "llvm/Support/BranchProbability.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/raw_ostream.h"
25#include "llvm/Transforms/Utils/BasicBlockUtils.h"
26#include <utility>
27#include <vector>
28
29#define DEBUG_TYPE "cfgmst"
30
31namespace llvm {
32
33/// An union-find based Minimum Spanning Tree for CFG
34///
35/// Implements a Union-find algorithm to compute Minimum Spanning Tree
36/// for a given CFG.
37template <class Edge, class BBInfo> class CFGMST {
38public:
39  Function &F;
40
41  // Store all the edges in CFG. It may contain some stale edges
42  // when Removed is set.
43  std::vector<std::unique_ptr<Edge>> AllEdges;
44
45  // This map records the auxiliary information for each BB.
46  DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos;
47
48  // Whehter the function has an exit block with no successors.
49  // (For function with an infinite loop, this block may be absent)
50  bool ExitBlockFound = false;
51
52  // Find the root group of the G and compress the path from G to the root.
53  BBInfo *findAndCompressGroup(BBInfo *G) {
54    if (G->Group != G)
55      G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group));
56    return static_cast<BBInfo *>(G->Group);
57  }
58
59  // Union BB1 and BB2 into the same group and return true.
60  // Returns false if BB1 and BB2 are already in the same group.
61  bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) {
62    BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1));
63    BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2));
64
65    if (BB1G == BB2G)
66      return false;
67
68    // Make the smaller rank tree a direct child or the root of high rank tree.
69    if (BB1G->Rank < BB2G->Rank)
70      BB1G->Group = BB2G;
71    else {
72      BB2G->Group = BB1G;
73      // If the ranks are the same, increment root of one tree by one.
74      if (BB1G->Rank == BB2G->Rank)
75        BB1G->Rank++;
76    }
77    return true;
78  }
79
80  // Give BB, return the auxiliary information.
81  BBInfo &getBBInfo(const BasicBlock *BB) const {
82    auto It = BBInfos.find(BB);
83    assert(It->second.get() != nullptr);
84    return *It->second.get();
85  }
86
87  // Give BB, return the auxiliary information if it's available.
88  BBInfo *findBBInfo(const BasicBlock *BB) const {
89    auto It = BBInfos.find(BB);
90    if (It == BBInfos.end())
91      return nullptr;
92    return It->second.get();
93  }
94
95  // Traverse the CFG using a stack. Find all the edges and assign the weight.
96  // Edges with large weight will be put into MST first so they are less likely
97  // to be instrumented.
98  void buildEdges() {
99    LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");
100
101    const BasicBlock *Entry = &(F.getEntryBlock());
102    uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2);
103    Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr,
104        *ExitOutgoing = nullptr, *ExitIncoming = nullptr;
105    uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0;
106
107    // Add a fake edge to the entry.
108    EntryIncoming = &addEdge(nullptr, Entry, EntryWeight);
109    LLVM_DEBUG(dbgs() << "  Edge: from fake node to " << Entry->getName()
110                      << " w = " << EntryWeight << "\n");
111
112    // Special handling for single BB functions.
113    if (succ_empty(Entry)) {
114      addEdge(Entry, nullptr, EntryWeight);
115      return;
116    }
117
118    static const uint32_t CriticalEdgeMultiplier = 1000;
119
120    for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) {
121      Instruction *TI = BB->getTerminator();
122      uint64_t BBWeight =
123          (BFI != nullptr ? BFI->getBlockFreq(&*BB).getFrequency() : 2);
124      uint64_t Weight = 2;
125      if (int successors = TI->getNumSuccessors()) {
126        for (int i = 0; i != successors; ++i) {
127          BasicBlock *TargetBB = TI->getSuccessor(i);
128          bool Critical = isCriticalEdge(TI, i);
129          uint64_t scaleFactor = BBWeight;
130          if (Critical) {
131            if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier)
132              scaleFactor *= CriticalEdgeMultiplier;
133            else
134              scaleFactor = UINT64_MAX;
135          }
136          if (BPI != nullptr)
137            Weight = BPI->getEdgeProbability(&*BB, TargetBB).scale(scaleFactor);
138          auto *E = &addEdge(&*BB, TargetBB, Weight);
139          E->IsCritical = Critical;
140          LLVM_DEBUG(dbgs() << "  Edge: from " << BB->getName() << " to "
141                            << TargetBB->getName() << "  w=" << Weight << "\n");
142
143          // Keep track of entry/exit edges:
144          if (&*BB == Entry) {
145            if (Weight > MaxEntryOutWeight) {
146              MaxEntryOutWeight = Weight;
147              EntryOutgoing = E;
148            }
149          }
150
151          auto *TargetTI = TargetBB->getTerminator();
152          if (TargetTI && !TargetTI->getNumSuccessors()) {
153            if (Weight > MaxExitInWeight) {
154              MaxExitInWeight = Weight;
155              ExitIncoming = E;
156            }
157          }
158        }
159      } else {
160        ExitBlockFound = true;
161        Edge *ExitO = &addEdge(&*BB, nullptr, BBWeight);
162        if (BBWeight > MaxExitOutWeight) {
163          MaxExitOutWeight = BBWeight;
164          ExitOutgoing = ExitO;
165        }
166        LLVM_DEBUG(dbgs() << "  Edge: from " << BB->getName() << " to fake exit"
167                          << " w = " << BBWeight << "\n");
168      }
169    }
170
171    // Entry/exit edge adjustment heurisitic:
172    // prefer instrumenting entry edge over exit edge
173    // if possible. Those exit edges may never have a chance to be
174    // executed (for instance the program is an event handling loop)
175    // before the profile is asynchronously dumped.
176    //
177    // If EntryIncoming and ExitOutgoing has similar weight, make sure
178    // ExitOutging is selected as the min-edge. Similarly, if EntryOutgoing
179    // and ExitIncoming has similar weight, make sure ExitIncoming becomes
180    // the min-edge.
181    uint64_t EntryInWeight = EntryWeight;
182
183    if (EntryInWeight >= MaxExitOutWeight &&
184        EntryInWeight * 2 < MaxExitOutWeight * 3) {
185      EntryIncoming->Weight = MaxExitOutWeight;
186      ExitOutgoing->Weight = EntryInWeight + 1;
187    }
188
189    if (MaxEntryOutWeight >= MaxExitInWeight &&
190        MaxEntryOutWeight * 2 < MaxExitInWeight * 3) {
191      EntryOutgoing->Weight = MaxExitInWeight;
192      ExitIncoming->Weight = MaxEntryOutWeight + 1;
193    }
194  }
195
196  // Sort CFG edges based on its weight.
197  void sortEdgesByWeight() {
198    llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1,
199                                   const std::unique_ptr<Edge> &Edge2) {
200      return Edge1->Weight > Edge2->Weight;
201    });
202  }
203
204  // Traverse all the edges and compute the Minimum Weight Spanning Tree
205  // using union-find algorithm.
206  void computeMinimumSpanningTree() {
207    // First, put all the critical edge with landing-pad as the Dest to MST.
208    // This works around the insufficient support of critical edges split
209    // when destination BB is a landing pad.
210    for (auto &Ei : AllEdges) {
211      if (Ei->Removed)
212        continue;
213      if (Ei->IsCritical) {
214        if (Ei->DestBB && Ei->DestBB->isLandingPad()) {
215          if (unionGroups(Ei->SrcBB, Ei->DestBB))
216            Ei->InMST = true;
217        }
218      }
219    }
220
221    for (auto &Ei : AllEdges) {
222      if (Ei->Removed)
223        continue;
224      // If we detect infinite loops, force
225      // instrumenting the entry edge:
226      if (!ExitBlockFound && Ei->SrcBB == nullptr)
227        continue;
228      if (unionGroups(Ei->SrcBB, Ei->DestBB))
229        Ei->InMST = true;
230    }
231  }
232
233  // Dump the Debug information about the instrumentation.
234  void dumpEdges(raw_ostream &OS, const Twine &Message) const {
235    if (!Message.str().empty())
236      OS << Message << "\n";
237    OS << "  Number of Basic Blocks: " << BBInfos.size() << "\n";
238    for (auto &BI : BBInfos) {
239      const BasicBlock *BB = BI.first;
240      OS << "  BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << "  "
241         << BI.second->infoString() << "\n";
242    }
243
244    OS << "  Number of Edges: " << AllEdges.size()
245       << " (*: Instrument, C: CriticalEdge, -: Removed)\n";
246    uint32_t Count = 0;
247    for (auto &EI : AllEdges)
248      OS << "  Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->"
249         << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n";
250  }
251
252  // Add an edge to AllEdges with weight W.
253  Edge &addEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W) {
254    uint32_t Index = BBInfos.size();
255    auto Iter = BBInfos.end();
256    bool Inserted;
257    std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr));
258    if (Inserted) {
259      // Newly inserted, update the real info.
260      Iter->second = std::move(std::make_unique<BBInfo>(Index));
261      Index++;
262    }
263    std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr));
264    if (Inserted)
265      // Newly inserted, update the real info.
266      Iter->second = std::move(std::make_unique<BBInfo>(Index));
267    AllEdges.emplace_back(new Edge(Src, Dest, W));
268    return *AllEdges.back();
269  }
270
271  BranchProbabilityInfo *BPI;
272  BlockFrequencyInfo *BFI;
273
274public:
275  CFGMST(Function &Func, BranchProbabilityInfo *BPI_ = nullptr,
276         BlockFrequencyInfo *BFI_ = nullptr)
277      : F(Func), BPI(BPI_), BFI(BFI_) {
278    buildEdges();
279    sortEdgesByWeight();
280    computeMinimumSpanningTree();
281  }
282};
283
284} // end namespace llvm
285
286#undef DEBUG_TYPE // "cfgmst"
287
288#endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_CFGMST_H
289