CodeGenPGO.cpp revision 360784
1//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Instrumentation-based profile-guided optimization
10//
11//===----------------------------------------------------------------------===//
12
13#include "CodeGenPGO.h"
14#include "CodeGenFunction.h"
15#include "CoverageMappingGen.h"
16#include "clang/AST/RecursiveASTVisitor.h"
17#include "clang/AST/StmtVisitor.h"
18#include "llvm/IR/Intrinsics.h"
19#include "llvm/IR/MDBuilder.h"
20#include "llvm/Support/CommandLine.h"
21#include "llvm/Support/Endian.h"
22#include "llvm/Support/FileSystem.h"
23#include "llvm/Support/MD5.h"
24
25static llvm::cl::opt<bool>
26    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27                         llvm::cl::desc("Enable value profiling"),
28                         llvm::cl::Hidden, llvm::cl::init(false));
29
30using namespace clang;
31using namespace CodeGen;
32
33void CodeGenPGO::setFuncName(StringRef Name,
34                             llvm::GlobalValue::LinkageTypes Linkage) {
35  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36  FuncName = llvm::getPGOFuncName(
37      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39
40  // If we're generating a profile, create a variable for the name.
41  if (CGM.getCodeGenOpts().hasProfileClangInstr())
42    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43}
44
45void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46  setFuncName(Fn->getName(), Fn->getLinkage());
47  // Create PGOFuncName meta data.
48  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49}
50
/// Identifies which revision of the PGO hash algorithm is in use.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,

  // Alias for the newest revision; update it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V2
};
59
60namespace {
61/// Stable hasher for PGO region counters.
62///
63/// PGOHash produces a stable hash of a given function's control flow.
64///
65/// Changing the output of this hash will invalidate all previously generated
66/// profiles -- i.e., don't do it.
67///
68/// \note  When this hash does eventually change (years?), we still need to
69/// support old hashes.  We'll need to pull in the version number from the
70/// profile data format and use the matching hash function.
71class PGOHash {
72  uint64_t Working;
73  unsigned Count;
74  PGOHashVersion HashVersion;
75  llvm::MD5 MD5;
76
77  static const int NumBitsPerType = 6;
78  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
79  static const unsigned TooBig = 1u << NumBitsPerType;
80
81public:
82  /// Hash values for AST nodes.
83  ///
84  /// Distinct values for AST nodes that have region counters attached.
85  ///
86  /// These values must be stable.  All new members must be added at the end,
87  /// and no members should be removed.  Changing the enumeration value for an
88  /// AST node will affect the hash of every function that contains that node.
89  enum HashType : unsigned char {
90    None = 0,
91    LabelStmt = 1,
92    WhileStmt,
93    DoStmt,
94    ForStmt,
95    CXXForRangeStmt,
96    ObjCForCollectionStmt,
97    SwitchStmt,
98    CaseStmt,
99    DefaultStmt,
100    IfStmt,
101    CXXTryStmt,
102    CXXCatchStmt,
103    ConditionalOperator,
104    BinaryOperatorLAnd,
105    BinaryOperatorLOr,
106    BinaryConditionalOperator,
107    // The preceding values are available with PGO_HASH_V1.
108
109    EndOfScope,
110    IfThenBranch,
111    IfElseBranch,
112    GotoStmt,
113    IndirectGotoStmt,
114    BreakStmt,
115    ContinueStmt,
116    ReturnStmt,
117    ThrowExpr,
118    UnaryOperatorLNot,
119    BinaryOperatorLT,
120    BinaryOperatorGT,
121    BinaryOperatorLE,
122    BinaryOperatorGE,
123    BinaryOperatorEQ,
124    BinaryOperatorNE,
125    // The preceding values are available with PGO_HASH_V2.
126
127    // Keep this last.  It's for the static assert that follows.
128    LastHashType
129  };
130  static_assert(LastHashType <= TooBig, "Too many types in HashType");
131
132  PGOHash(PGOHashVersion HashVersion)
133      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
134  void combine(HashType Type);
135  uint64_t finalize();
136  PGOHashVersion getHashVersion() const { return HashVersion; }
137};
138const int PGOHash::NumBitsPerType;
139const unsigned PGOHash::NumTypesPerWord;
140const unsigned PGOHash::TooBig;
141
142/// Get the PGO hash version used in the given indexed profile.
143static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
144                                        CodeGenModule &CGM) {
145  if (PGOReader->getVersion() <= 4)
146    return PGO_HASH_V1;
147  return PGO_HASH_V2;
148}
149
150/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
151struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
152  using Base = RecursiveASTVisitor<MapRegionCounters>;
153
154  /// The next counter value to assign.
155  unsigned NextCounter;
156  /// The function hash.
157  PGOHash Hash;
158  /// The map of statements to counters.
159  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
160
161  MapRegionCounters(PGOHashVersion HashVersion,
162                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
163      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
164
165  // Blocks and lambdas are handled as separate functions, so we need not
166  // traverse them in the parent context.
167  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
168  bool TraverseLambdaExpr(LambdaExpr *LE) {
169    // Traverse the captures, but not the body.
170    for (auto C : zip(LE->captures(), LE->capture_inits()))
171      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
172    return true;
173  }
174  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
175
176  bool VisitDecl(const Decl *D) {
177    switch (D->getKind()) {
178    default:
179      break;
180    case Decl::Function:
181    case Decl::CXXMethod:
182    case Decl::CXXConstructor:
183    case Decl::CXXDestructor:
184    case Decl::CXXConversion:
185    case Decl::ObjCMethod:
186    case Decl::Block:
187    case Decl::Captured:
188      CounterMap[D->getBody()] = NextCounter++;
189      break;
190    }
191    return true;
192  }
193
194  /// If \p S gets a fresh counter, update the counter mappings. Return the
195  /// V1 hash of \p S.
196  PGOHash::HashType updateCounterMappings(Stmt *S) {
197    auto Type = getHashType(PGO_HASH_V1, S);
198    if (Type != PGOHash::None)
199      CounterMap[S] = NextCounter++;
200    return Type;
201  }
202
203  /// Include \p S in the function hash.
204  bool VisitStmt(Stmt *S) {
205    auto Type = updateCounterMappings(S);
206    if (Hash.getHashVersion() != PGO_HASH_V1)
207      Type = getHashType(Hash.getHashVersion(), S);
208    if (Type != PGOHash::None)
209      Hash.combine(Type);
210    return true;
211  }
212
213  bool TraverseIfStmt(IfStmt *If) {
214    // If we used the V1 hash, use the default traversal.
215    if (Hash.getHashVersion() == PGO_HASH_V1)
216      return Base::TraverseIfStmt(If);
217
218    // Otherwise, keep track of which branch we're in while traversing.
219    VisitStmt(If);
220    for (Stmt *CS : If->children()) {
221      if (!CS)
222        continue;
223      if (CS == If->getThen())
224        Hash.combine(PGOHash::IfThenBranch);
225      else if (CS == If->getElse())
226        Hash.combine(PGOHash::IfElseBranch);
227      TraverseStmt(CS);
228    }
229    Hash.combine(PGOHash::EndOfScope);
230    return true;
231  }
232
233// If the statement type \p N is nestable, and its nesting impacts profile
234// stability, define a custom traversal which tracks the end of the statement
235// in the hash (provided we're not using the V1 hash).
236#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
237  bool Traverse##N(N *S) {                                                     \
238    Base::Traverse##N(S);                                                      \
239    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
240      Hash.combine(PGOHash::EndOfScope);                                       \
241    return true;                                                               \
242  }
243
244  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
245  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
246  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
247  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
248  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
249  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
250  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
251
252  /// Get version \p HashVersion of the PGO hash for \p S.
253  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
254    switch (S->getStmtClass()) {
255    default:
256      break;
257    case Stmt::LabelStmtClass:
258      return PGOHash::LabelStmt;
259    case Stmt::WhileStmtClass:
260      return PGOHash::WhileStmt;
261    case Stmt::DoStmtClass:
262      return PGOHash::DoStmt;
263    case Stmt::ForStmtClass:
264      return PGOHash::ForStmt;
265    case Stmt::CXXForRangeStmtClass:
266      return PGOHash::CXXForRangeStmt;
267    case Stmt::ObjCForCollectionStmtClass:
268      return PGOHash::ObjCForCollectionStmt;
269    case Stmt::SwitchStmtClass:
270      return PGOHash::SwitchStmt;
271    case Stmt::CaseStmtClass:
272      return PGOHash::CaseStmt;
273    case Stmt::DefaultStmtClass:
274      return PGOHash::DefaultStmt;
275    case Stmt::IfStmtClass:
276      return PGOHash::IfStmt;
277    case Stmt::CXXTryStmtClass:
278      return PGOHash::CXXTryStmt;
279    case Stmt::CXXCatchStmtClass:
280      return PGOHash::CXXCatchStmt;
281    case Stmt::ConditionalOperatorClass:
282      return PGOHash::ConditionalOperator;
283    case Stmt::BinaryConditionalOperatorClass:
284      return PGOHash::BinaryConditionalOperator;
285    case Stmt::BinaryOperatorClass: {
286      const BinaryOperator *BO = cast<BinaryOperator>(S);
287      if (BO->getOpcode() == BO_LAnd)
288        return PGOHash::BinaryOperatorLAnd;
289      if (BO->getOpcode() == BO_LOr)
290        return PGOHash::BinaryOperatorLOr;
291      if (HashVersion == PGO_HASH_V2) {
292        switch (BO->getOpcode()) {
293        default:
294          break;
295        case BO_LT:
296          return PGOHash::BinaryOperatorLT;
297        case BO_GT:
298          return PGOHash::BinaryOperatorGT;
299        case BO_LE:
300          return PGOHash::BinaryOperatorLE;
301        case BO_GE:
302          return PGOHash::BinaryOperatorGE;
303        case BO_EQ:
304          return PGOHash::BinaryOperatorEQ;
305        case BO_NE:
306          return PGOHash::BinaryOperatorNE;
307        }
308      }
309      break;
310    }
311    }
312
313    if (HashVersion == PGO_HASH_V2) {
314      switch (S->getStmtClass()) {
315      default:
316        break;
317      case Stmt::GotoStmtClass:
318        return PGOHash::GotoStmt;
319      case Stmt::IndirectGotoStmtClass:
320        return PGOHash::IndirectGotoStmt;
321      case Stmt::BreakStmtClass:
322        return PGOHash::BreakStmt;
323      case Stmt::ContinueStmtClass:
324        return PGOHash::ContinueStmt;
325      case Stmt::ReturnStmtClass:
326        return PGOHash::ReturnStmt;
327      case Stmt::CXXThrowExprClass:
328        return PGOHash::ThrowExpr;
329      case Stmt::UnaryOperatorClass: {
330        const UnaryOperator *UO = cast<UnaryOperator>(S);
331        if (UO->getOpcode() == UO_LNot)
332          return PGOHash::UnaryOperatorLNot;
333        break;
334      }
335      }
336    }
337
338    return PGOHash::None;
339  }
340};
341
342/// A StmtVisitor that propagates the raw counts through the AST and
343/// records the count at statements where the value may change.
344struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
345  /// PGO state.
346  CodeGenPGO &PGO;
347
348  /// A flag that is set when the current count should be recorded on the
349  /// next statement, such as at the exit of a loop.
350  bool RecordNextStmtCount;
351
352  /// The count at the current location in the traversal.
353  uint64_t CurrentCount;
354
355  /// The map of statements to count values.
356  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
357
358  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
359  struct BreakContinue {
360    uint64_t BreakCount;
361    uint64_t ContinueCount;
362    BreakContinue() : BreakCount(0), ContinueCount(0) {}
363  };
364  SmallVector<BreakContinue, 8> BreakContinueStack;
365
366  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
367                      CodeGenPGO &PGO)
368      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
369
370  void RecordStmtCount(const Stmt *S) {
371    if (RecordNextStmtCount) {
372      CountMap[S] = CurrentCount;
373      RecordNextStmtCount = false;
374    }
375  }
376
377  /// Set and return the current count.
378  uint64_t setCount(uint64_t Count) {
379    CurrentCount = Count;
380    return Count;
381  }
382
383  void VisitStmt(const Stmt *S) {
384    RecordStmtCount(S);
385    for (const Stmt *Child : S->children())
386      if (Child)
387        this->Visit(Child);
388  }
389
390  void VisitFunctionDecl(const FunctionDecl *D) {
391    // Counter tracks entry to the function body.
392    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
393    CountMap[D->getBody()] = BodyCount;
394    Visit(D->getBody());
395  }
396
397  // Skip lambda expressions. We visit these as FunctionDecls when we're
398  // generating them and aren't interested in the body when generating a
399  // parent context.
400  void VisitLambdaExpr(const LambdaExpr *LE) {}
401
402  void VisitCapturedDecl(const CapturedDecl *D) {
403    // Counter tracks entry to the capture body.
404    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
405    CountMap[D->getBody()] = BodyCount;
406    Visit(D->getBody());
407  }
408
409  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
410    // Counter tracks entry to the method body.
411    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
412    CountMap[D->getBody()] = BodyCount;
413    Visit(D->getBody());
414  }
415
416  void VisitBlockDecl(const BlockDecl *D) {
417    // Counter tracks entry to the block body.
418    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
419    CountMap[D->getBody()] = BodyCount;
420    Visit(D->getBody());
421  }
422
423  void VisitReturnStmt(const ReturnStmt *S) {
424    RecordStmtCount(S);
425    if (S->getRetValue())
426      Visit(S->getRetValue());
427    CurrentCount = 0;
428    RecordNextStmtCount = true;
429  }
430
431  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
432    RecordStmtCount(E);
433    if (E->getSubExpr())
434      Visit(E->getSubExpr());
435    CurrentCount = 0;
436    RecordNextStmtCount = true;
437  }
438
439  void VisitGotoStmt(const GotoStmt *S) {
440    RecordStmtCount(S);
441    CurrentCount = 0;
442    RecordNextStmtCount = true;
443  }
444
445  void VisitLabelStmt(const LabelStmt *S) {
446    RecordNextStmtCount = false;
447    // Counter tracks the block following the label.
448    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
449    CountMap[S] = BlockCount;
450    Visit(S->getSubStmt());
451  }
452
453  void VisitBreakStmt(const BreakStmt *S) {
454    RecordStmtCount(S);
455    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
456    BreakContinueStack.back().BreakCount += CurrentCount;
457    CurrentCount = 0;
458    RecordNextStmtCount = true;
459  }
460
461  void VisitContinueStmt(const ContinueStmt *S) {
462    RecordStmtCount(S);
463    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
464    BreakContinueStack.back().ContinueCount += CurrentCount;
465    CurrentCount = 0;
466    RecordNextStmtCount = true;
467  }
468
469  void VisitWhileStmt(const WhileStmt *S) {
470    RecordStmtCount(S);
471    uint64_t ParentCount = CurrentCount;
472
473    BreakContinueStack.push_back(BreakContinue());
474    // Visit the body region first so the break/continue adjustments can be
475    // included when visiting the condition.
476    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
477    CountMap[S->getBody()] = CurrentCount;
478    Visit(S->getBody());
479    uint64_t BackedgeCount = CurrentCount;
480
481    // ...then go back and propagate counts through the condition. The count
482    // at the start of the condition is the sum of the incoming edges,
483    // the backedge from the end of the loop body, and the edges from
484    // continue statements.
485    BreakContinue BC = BreakContinueStack.pop_back_val();
486    uint64_t CondCount =
487        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
488    CountMap[S->getCond()] = CondCount;
489    Visit(S->getCond());
490    setCount(BC.BreakCount + CondCount - BodyCount);
491    RecordNextStmtCount = true;
492  }
493
494  void VisitDoStmt(const DoStmt *S) {
495    RecordStmtCount(S);
496    uint64_t LoopCount = PGO.getRegionCount(S);
497
498    BreakContinueStack.push_back(BreakContinue());
499    // The count doesn't include the fallthrough from the parent scope. Add it.
500    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
501    CountMap[S->getBody()] = BodyCount;
502    Visit(S->getBody());
503    uint64_t BackedgeCount = CurrentCount;
504
505    BreakContinue BC = BreakContinueStack.pop_back_val();
506    // The count at the start of the condition is equal to the count at the
507    // end of the body, plus any continues.
508    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
509    CountMap[S->getCond()] = CondCount;
510    Visit(S->getCond());
511    setCount(BC.BreakCount + CondCount - LoopCount);
512    RecordNextStmtCount = true;
513  }
514
515  void VisitForStmt(const ForStmt *S) {
516    RecordStmtCount(S);
517    if (S->getInit())
518      Visit(S->getInit());
519
520    uint64_t ParentCount = CurrentCount;
521
522    BreakContinueStack.push_back(BreakContinue());
523    // Visit the body region first. (This is basically the same as a while
524    // loop; see further comments in VisitWhileStmt.)
525    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
526    CountMap[S->getBody()] = BodyCount;
527    Visit(S->getBody());
528    uint64_t BackedgeCount = CurrentCount;
529    BreakContinue BC = BreakContinueStack.pop_back_val();
530
531    // The increment is essentially part of the body but it needs to include
532    // the count for all the continue statements.
533    if (S->getInc()) {
534      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
535      CountMap[S->getInc()] = IncCount;
536      Visit(S->getInc());
537    }
538
539    // ...then go back and propagate counts through the condition.
540    uint64_t CondCount =
541        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
542    if (S->getCond()) {
543      CountMap[S->getCond()] = CondCount;
544      Visit(S->getCond());
545    }
546    setCount(BC.BreakCount + CondCount - BodyCount);
547    RecordNextStmtCount = true;
548  }
549
550  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
551    RecordStmtCount(S);
552    if (S->getInit())
553      Visit(S->getInit());
554    Visit(S->getLoopVarStmt());
555    Visit(S->getRangeStmt());
556    Visit(S->getBeginStmt());
557    Visit(S->getEndStmt());
558
559    uint64_t ParentCount = CurrentCount;
560    BreakContinueStack.push_back(BreakContinue());
561    // Visit the body region first. (This is basically the same as a while
562    // loop; see further comments in VisitWhileStmt.)
563    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
564    CountMap[S->getBody()] = BodyCount;
565    Visit(S->getBody());
566    uint64_t BackedgeCount = CurrentCount;
567    BreakContinue BC = BreakContinueStack.pop_back_val();
568
569    // The increment is essentially part of the body but it needs to include
570    // the count for all the continue statements.
571    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
572    CountMap[S->getInc()] = IncCount;
573    Visit(S->getInc());
574
575    // ...then go back and propagate counts through the condition.
576    uint64_t CondCount =
577        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
578    CountMap[S->getCond()] = CondCount;
579    Visit(S->getCond());
580    setCount(BC.BreakCount + CondCount - BodyCount);
581    RecordNextStmtCount = true;
582  }
583
584  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
585    RecordStmtCount(S);
586    Visit(S->getElement());
587    uint64_t ParentCount = CurrentCount;
588    BreakContinueStack.push_back(BreakContinue());
589    // Counter tracks the body of the loop.
590    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
591    CountMap[S->getBody()] = BodyCount;
592    Visit(S->getBody());
593    uint64_t BackedgeCount = CurrentCount;
594    BreakContinue BC = BreakContinueStack.pop_back_val();
595
596    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
597             BodyCount);
598    RecordNextStmtCount = true;
599  }
600
601  void VisitSwitchStmt(const SwitchStmt *S) {
602    RecordStmtCount(S);
603    if (S->getInit())
604      Visit(S->getInit());
605    Visit(S->getCond());
606    CurrentCount = 0;
607    BreakContinueStack.push_back(BreakContinue());
608    Visit(S->getBody());
609    // If the switch is inside a loop, add the continue counts.
610    BreakContinue BC = BreakContinueStack.pop_back_val();
611    if (!BreakContinueStack.empty())
612      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
613    // Counter tracks the exit block of the switch.
614    setCount(PGO.getRegionCount(S));
615    RecordNextStmtCount = true;
616  }
617
618  void VisitSwitchCase(const SwitchCase *S) {
619    RecordNextStmtCount = false;
620    // Counter for this particular case. This counts only jumps from the
621    // switch header and does not include fallthrough from the case before
622    // this one.
623    uint64_t CaseCount = PGO.getRegionCount(S);
624    setCount(CurrentCount + CaseCount);
625    // We need the count without fallthrough in the mapping, so it's more useful
626    // for branch probabilities.
627    CountMap[S] = CaseCount;
628    RecordNextStmtCount = true;
629    Visit(S->getSubStmt());
630  }
631
632  void VisitIfStmt(const IfStmt *S) {
633    RecordStmtCount(S);
634    uint64_t ParentCount = CurrentCount;
635    if (S->getInit())
636      Visit(S->getInit());
637    Visit(S->getCond());
638
639    // Counter tracks the "then" part of an if statement. The count for
640    // the "else" part, if it exists, will be calculated from this counter.
641    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
642    CountMap[S->getThen()] = ThenCount;
643    Visit(S->getThen());
644    uint64_t OutCount = CurrentCount;
645
646    uint64_t ElseCount = ParentCount - ThenCount;
647    if (S->getElse()) {
648      setCount(ElseCount);
649      CountMap[S->getElse()] = ElseCount;
650      Visit(S->getElse());
651      OutCount += CurrentCount;
652    } else
653      OutCount += ElseCount;
654    setCount(OutCount);
655    RecordNextStmtCount = true;
656  }
657
658  void VisitCXXTryStmt(const CXXTryStmt *S) {
659    RecordStmtCount(S);
660    Visit(S->getTryBlock());
661    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
662      Visit(S->getHandler(I));
663    // Counter tracks the continuation block of the try statement.
664    setCount(PGO.getRegionCount(S));
665    RecordNextStmtCount = true;
666  }
667
668  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
669    RecordNextStmtCount = false;
670    // Counter tracks the catch statement's handler block.
671    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
672    CountMap[S] = CatchCount;
673    Visit(S->getHandlerBlock());
674  }
675
676  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
677    RecordStmtCount(E);
678    uint64_t ParentCount = CurrentCount;
679    Visit(E->getCond());
680
681    // Counter tracks the "true" part of a conditional operator. The
682    // count in the "false" part will be calculated from this counter.
683    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
684    CountMap[E->getTrueExpr()] = TrueCount;
685    Visit(E->getTrueExpr());
686    uint64_t OutCount = CurrentCount;
687
688    uint64_t FalseCount = setCount(ParentCount - TrueCount);
689    CountMap[E->getFalseExpr()] = FalseCount;
690    Visit(E->getFalseExpr());
691    OutCount += CurrentCount;
692
693    setCount(OutCount);
694    RecordNextStmtCount = true;
695  }
696
697  void VisitBinLAnd(const BinaryOperator *E) {
698    RecordStmtCount(E);
699    uint64_t ParentCount = CurrentCount;
700    Visit(E->getLHS());
701    // Counter tracks the right hand side of a logical and operator.
702    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
703    CountMap[E->getRHS()] = RHSCount;
704    Visit(E->getRHS());
705    setCount(ParentCount + RHSCount - CurrentCount);
706    RecordNextStmtCount = true;
707  }
708
709  void VisitBinLOr(const BinaryOperator *E) {
710    RecordStmtCount(E);
711    uint64_t ParentCount = CurrentCount;
712    Visit(E->getLHS());
713    // Counter tracks the right hand side of a logical or operator.
714    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
715    CountMap[E->getRHS()] = RHSCount;
716    Visit(E->getRHS());
717    setCount(ParentCount + RHSCount - CurrentCount);
718    RecordNextStmtCount = true;
719  }
720};
721} // end anonymous namespace
722
723void PGOHash::combine(HashType Type) {
724  // Check that we never combine 0 and only have six bits.
725  assert(Type && "Hash is invalid: unexpected type 0");
726  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
727
728  // Pass through MD5 if enough work has built up.
729  if (Count && Count % NumTypesPerWord == 0) {
730    using namespace llvm::support;
731    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
732    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
733    Working = 0;
734  }
735
736  // Accumulate the current type.
737  ++Count;
738  Working = Working << NumBitsPerType | Type;
739}
740
741uint64_t PGOHash::finalize() {
742  // Use Working as the hash directly if we never used MD5.
743  if (Count <= NumTypesPerWord)
744    // No need to byte swap here, since none of the math was endian-dependent.
745    // This number will be byte-swapped as required on endianness transitions,
746    // so we will see the same value on the other side.
747    return Working;
748
749  // Check for remaining work in Working.
750  if (Working)
751    MD5.update(Working);
752
753  // Finalize the MD5 and return the hash.
754  llvm::MD5::MD5Result Result;
755  MD5.final(Result);
756  using namespace llvm::support;
757  return Result.low();
758}
759
760void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
761  const Decl *D = GD.getDecl();
762  if (!D->hasBody())
763    return;
764
765  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
766  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
767  if (!InstrumentRegions && !PGOReader)
768    return;
769  if (D->isImplicit())
770    return;
771  // Constructors and destructors may be represented by several functions in IR.
772  // If so, instrument only base variant, others are implemented by delegation
773  // to the base one, it would be counted twice otherwise.
774  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
775    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
776      if (GD.getCtorType() != Ctor_Base &&
777          CodeGenFunction::IsConstructorDelegationValid(CCD))
778        return;
779  }
780  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
781    return;
782
783  CGM.ClearUnusedCoverageMapping(D);
784  setFuncName(Fn);
785
786  mapRegionCounters(D);
787  if (CGM.getCodeGenOpts().CoverageMapping)
788    emitCounterRegionMapping(D);
789  if (PGOReader) {
790    SourceManager &SM = CGM.getContext().getSourceManager();
791    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
792    computeRegionCounts(D);
793    applyFunctionAttributes(PGOReader, Fn);
794  }
795}
796
797void CodeGenPGO::mapRegionCounters(const Decl *D) {
798  // Use the latest hash version when inserting instrumentation, but use the
799  // version in the indexed profile if we're reading PGO data.
800  PGOHashVersion HashVersion = PGO_HASH_LATEST;
801  if (auto *PGOReader = CGM.getPGOReader())
802    HashVersion = getPGOHashVersion(PGOReader, CGM);
803
804  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
805  MapRegionCounters Walker(HashVersion, *RegionCounterMap);
806  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
807    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
808  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
809    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
810  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
811    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
812  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
813    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
814  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
815  NumRegionCounters = Walker.NextCounter;
816  FunctionHash = Walker.Hash.finalize();
817}
818
819bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
820  if (!D->getBody())
821    return true;
822
823  // Don't map the functions in system headers.
824  const auto &SM = CGM.getContext().getSourceManager();
825  auto Loc = D->getBody()->getBeginLoc();
826  return SM.isInSystemHeader(Loc);
827}
828
829void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
830  if (skipRegionMappingForDecl(D))
831    return;
832
833  std::string CoverageMapping;
834  llvm::raw_string_ostream OS(CoverageMapping);
835  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
836                                CGM.getContext().getSourceManager(),
837                                CGM.getLangOpts(), RegionCounterMap.get());
838  MappingGen.emitCounterMapping(D, OS);
839  OS.flush();
840
841  if (CoverageMapping.empty())
842    return;
843
844  CGM.getCoverageMapping()->addFunctionMappingRecord(
845      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
846}
847
848void
849CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
850                                    llvm::GlobalValue::LinkageTypes Linkage) {
851  if (skipRegionMappingForDecl(D))
852    return;
853
854  std::string CoverageMapping;
855  llvm::raw_string_ostream OS(CoverageMapping);
856  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
857                                CGM.getContext().getSourceManager(),
858                                CGM.getLangOpts());
859  MappingGen.emitEmptyMapping(D, OS);
860  OS.flush();
861
862  if (CoverageMapping.empty())
863    return;
864
865  setFuncName(Name, Linkage);
866  CGM.getCoverageMapping()->addFunctionMappingRecord(
867      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
868}
869
870void CodeGenPGO::computeRegionCounts(const Decl *D) {
871  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
872  ComputeRegionCounts Walker(*StmtCountMap, *this);
873  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
874    Walker.VisitFunctionDecl(FD);
875  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
876    Walker.VisitObjCMethodDecl(MD);
877  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
878    Walker.VisitBlockDecl(BD);
879  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
880    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
881}
882
883void
884CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
885                                    llvm::Function *Fn) {
886  if (!haveRegionCounts())
887    return;
888
889  uint64_t FunctionCount = getRegionCount(nullptr);
890  Fn->setEntryCount(FunctionCount);
891}
892
893void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
894                                      llvm::Value *StepV) {
895  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
896    return;
897  if (!Builder.GetInsertBlock())
898    return;
899
900  unsigned Counter = (*RegionCounterMap)[S];
901  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
902
903  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
904                         Builder.getInt64(FunctionHash),
905                         Builder.getInt32(NumRegionCounters),
906                         Builder.getInt32(Counter), StepV};
907  if (!StepV)
908    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
909                       makeArrayRef(Args, 4));
910  else
911    Builder.CreateCall(
912        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
913        makeArrayRef(Args));
914}
915
916// This method either inserts a call to the profile run-time during
917// instrumentation or puts profile data into metadata for PGO use.
918void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
919    llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
920
921  if (!EnableValueProfiling)
922    return;
923
924  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
925    return;
926
927  if (isa<llvm::Constant>(ValuePtr))
928    return;
929
930  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
931  if (InstrumentValueSites && RegionCounterMap) {
932    auto BuilderInsertPoint = Builder.saveIP();
933    Builder.SetInsertPoint(ValueSite);
934    llvm::Value *Args[5] = {
935        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
936        Builder.getInt64(FunctionHash),
937        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
938        Builder.getInt32(ValueKind),
939        Builder.getInt32(NumValueSites[ValueKind]++)
940    };
941    Builder.CreateCall(
942        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
943    Builder.restoreIP(BuilderInsertPoint);
944    return;
945  }
946
947  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
948  if (PGOReader && haveRegionCounts()) {
949    // We record the top most called three functions at each call site.
950    // Profile metadata contains "VP" string identifying this metadata
951    // as value profiling data, then a uint32_t value for the value profiling
952    // kind, a uint64_t value for the total number of times the call is
953    // executed, followed by the function hash and execution count (uint64_t)
954    // pairs for each function.
955    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
956      return;
957
958    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
959                            (llvm::InstrProfValueKind)ValueKind,
960                            NumValueSites[ValueKind]);
961
962    NumValueSites[ValueKind]++;
963  }
964}
965
966void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
967                                  bool IsInMainFile) {
968  CGM.getPGOStats().addVisited(IsInMainFile);
969  RegionCounts.clear();
970  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
971      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
972  if (auto E = RecordExpected.takeError()) {
973    auto IPE = llvm::InstrProfError::take(std::move(E));
974    if (IPE == llvm::instrprof_error::unknown_function)
975      CGM.getPGOStats().addMissing(IsInMainFile);
976    else if (IPE == llvm::instrprof_error::hash_mismatch)
977      CGM.getPGOStats().addMismatched(IsInMainFile);
978    else if (IPE == llvm::instrprof_error::malformed)
979      // TODO: Consider a more specific warning for this case.
980      CGM.getPGOStats().addMismatched(IsInMainFile);
981    return;
982  }
983  ProfRecord =
984      std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
985  RegionCounts = ProfRecord->Counts;
986}
987
/// Compute the divisor used to scale branch weights.
///
/// \param MaxWeight the largest weight that is going to be scaled.
/// \returns 1 when \p MaxWeight is already below UINT32_MAX; otherwise a
/// divisor large enough that every weight up to \p MaxWeight, once divided by
/// it, is strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
995
/// Scale one branch weight down to 32 bits and add 1.
///
/// The 64-bit \p Weight is divided by \p Scale and the quotient is
/// incremented by one: following Laplace's Rule of Succession, weights are
/// better computed from the count plus 1, so the increment is applied
/// unconditionally.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Adjusted = Weight / Scale + 1;
  assert(Adjusted <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Adjusted);
}
1011
1012llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1013                                                    uint64_t FalseCount) {
1014  // Check for empty weights.
1015  if (!TrueCount && !FalseCount)
1016    return nullptr;
1017
1018  // Calculate how to scale down to 32-bits.
1019  uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1020
1021  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1022  return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1023                                      scaleBranchWeight(FalseCount, Scale));
1024}
1025
1026llvm::MDNode *
1027CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
1028  // We need at least two elements to create meaningful weights.
1029  if (Weights.size() < 2)
1030    return nullptr;
1031
1032  // Check for empty weights.
1033  uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1034  if (MaxWeight == 0)
1035    return nullptr;
1036
1037  // Calculate how to scale down to 32-bits.
1038  uint64_t Scale = calculateWeightScale(MaxWeight);
1039
1040  SmallVector<uint32_t, 16> ScaledWeights;
1041  ScaledWeights.reserve(Weights.size());
1042  for (uint64_t W : Weights)
1043    ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1044
1045  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1046  return MDHelper.createBranchWeights(ScaledWeights);
1047}
1048
1049llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1050                                                           uint64_t LoopCount) {
1051  if (!PGO.haveRegionCounts())
1052    return nullptr;
1053  Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1054  assert(CondCount.hasValue() && "missing expected loop condition count");
1055  if (*CondCount == 0)
1056    return nullptr;
1057  return createProfileWeights(LoopCount,
1058                              std::max(*CondCount, LoopCount) - LoopCount);
1059}
1060