1//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Instrumentation-based profile-guided optimization
10//
11//===----------------------------------------------------------------------===//
12
13#include "CodeGenPGO.h"
14#include "CodeGenFunction.h"
15#include "CoverageMappingGen.h"
16#include "clang/AST/RecursiveASTVisitor.h"
17#include "clang/AST/StmtVisitor.h"
18#include "llvm/IR/Intrinsics.h"
19#include "llvm/IR/MDBuilder.h"
20#include "llvm/Support/CommandLine.h"
21#include "llvm/Support/Endian.h"
22#include "llvm/Support/FileSystem.h"
23#include "llvm/Support/MD5.h"
24#include <optional>
25
// Hidden command-line knob that turns on value-profiling instrumentation in
// addition to the regular region counters (off by default).
static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling",
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));

// Declared here, defined in another TU (presumably the coverage-mapping
// implementation — confirm); controls coverage of system headers.
extern llvm::cl::opt<bool> SystemHeadersCoverage;

using namespace clang;
using namespace CodeGen;
35
36void CodeGenPGO::setFuncName(StringRef Name,
37                             llvm::GlobalValue::LinkageTypes Linkage) {
38  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
39  FuncName = llvm::getPGOFuncName(
40      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
41      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
42
43  // If we're generating a profile, create a variable for the name.
44  if (CGM.getCodeGenOpts().hasProfileClangInstr())
45    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
46}
47
48void CodeGenPGO::setFuncName(llvm::Function *Fn) {
49  setFuncName(Fn->getName(), Fn->getLinkage());
50  // Create PGOFuncName meta data.
51  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
52}
53
/// The version of the PGO hash algorithm.
///
/// The hash version is tied to the indexed profile format version (see
/// getPGOHashVersion below: format <= 4 maps to V1, <= 5 to V2, newer to V3).
/// Existing enumerators must never be renumbered — profiles on disk depend
/// on them.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};
63
64namespace {
65/// Stable hasher for PGO region counters.
66///
67/// PGOHash produces a stable hash of a given function's control flow.
68///
69/// Changing the output of this hash will invalidate all previously generated
70/// profiles -- i.e., don't do it.
71///
72/// \note  When this hash does eventually change (years?), we still need to
73/// support old hashes.  We'll need to pull in the version number from the
74/// profile data format and use the matching hash function.
class PGOHash {
  // Rolling 64-bit accumulator: each combined type is shifted in six bits at
  // a time (see combine()).
  uint64_t Working;
  // Number of types combined so far; used to decide when Working is full and
  // must be flushed into the MD5 state.
  unsigned Count;
  PGOHashVersion HashVersion;
  llvm::MD5 MD5;

  // Six bits per type allows up to 64 distinct HashType values and packs ten
  // types into each 64-bit word before an MD5 flush is needed.
  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion) {}
  // Fold one AST-node type into the hash; Type must be non-None and fit in
  // NumBitsPerType bits.
  void combine(HashType Type);
  // Produce the final 64-bit function hash from the accumulated state.
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
// Out-of-line definitions for the in-class-initialized static constants
// above.  These are needed to satisfy ODR-use before C++17 (where static
// constexpr/const integral members became implicitly inline); harmless
// otherwise.
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
145
146/// Get the PGO hash version used in the given indexed profile.
147static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
148                                        CodeGenModule &CGM) {
149  if (PGOReader->getVersion() <= 4)
150    return PGO_HASH_V1;
151  if (PGOReader->getVersion() <= 5)
152    return PGO_HASH_V2;
153  return PGO_HASH_V3;
154}
155
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The next bitmap byte index to assign.
  unsigned NextMCDCBitmapIdx;
  /// The map of statements to MC/DC bitmap coverage objects.
  llvm::DenseMap<const Stmt *, unsigned> &MCDCBitmapMap;
  /// Maximum number of supported MC/DC conditions in a boolean expression.
  /// A value of 0 means MC/DC is disabled.
  unsigned MCDCMaxCond;
  /// The profile version.
  uint64_t ProfileVersion;
  /// Diagnostics Engine used to report warnings.
  DiagnosticsEngine &Diag;

  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap,
                    llvm::DenseMap<const Stmt *, unsigned> &MCDCBitmapMap,
                    unsigned MCDCMaxCond, DiagnosticsEngine &Diag)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        NextMCDCBitmapIdx(0), MCDCBitmapMap(MCDCBitmapMap),
        MCDCMaxCond(MCDCMaxCond), ProfileVersion(ProfileVersion), Diag(Diag) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  // Assign the entry counter: each function-like declaration kind gets a
  // counter on its body, tracking entry into the function.
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    // Counter assignment always follows the V1 classification so that
    // counter indices stay stable across hash versions.
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The following stacks are used with dataTraverseStmtPre() and
  /// dataTraverseStmtPost() to track the depth of nested logical operators in a
  /// boolean expression in a function.  The ultimate purpose is to keep track
  /// of the number of leaf-level conditions in the boolean expression so that a
  /// profile bitmap can be allocated based on that number.
  ///
  /// The stacks are also used to find error cases and notify the user.  A
  /// standard logical operator nest for a boolean expression could be in a form
  /// similar to this: "x = a && b && c && (d || f)"
  unsigned NumCond = 0;
  bool SplitNestedLogicalOp = false;
  SmallVector<const Stmt *, 16> NonLogOpStack;
  SmallVector<const BinaryOperator *, 16> LogOpStack;

  // Hook: dataTraverseStmtPre() is invoked prior to visiting an AST Stmt node.
  bool dataTraverseStmtPre(Stmt *S) {
    /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
    if (MCDCMaxCond == 0)
      return true;

    /// At the top of the logical operator nest, reset the number of conditions,
    /// also forget previously seen split nesting cases.
    if (LogOpStack.empty()) {
      NumCond = 0;
      SplitNestedLogicalOp = false;
    }

    if (const Expr *E = dyn_cast<Expr>(S)) {
      const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
      if (BinOp && BinOp->isLogicalOp()) {
        /// Check for "split-nested" logical operators. This happens when a new
        /// boolean expression logical-op nest is encountered within an existing
        /// boolean expression, separated by a non-logical operator.  For
        /// example, in "x = (a && b && c && foo(d && f))", the "d && f" case
        /// starts a new boolean expression that is separated from the other
        /// conditions by the operator foo(). Split-nested cases are not
        /// supported by MC/DC.
        SplitNestedLogicalOp = SplitNestedLogicalOp || !NonLogOpStack.empty();

        LogOpStack.push_back(BinOp);
        return true;
      }
    }

    /// Keep track of non-logical operators. These are OK as long as we don't
    /// encounter a new logical operator after seeing one.
    if (!LogOpStack.empty())
      NonLogOpStack.push_back(S);

    return true;
  }

  // Hook: dataTraverseStmtPost() is invoked by the AST visitor after visiting
  // an AST Stmt node.  MC/DC will use it to signal when the top of a
  // logical operation (boolean expression) nest is encountered.
  bool dataTraverseStmtPost(Stmt *S) {
    /// If MC/DC is not enabled, MCDCMaxCond will be set to 0. Do nothing.
    if (MCDCMaxCond == 0)
      return true;

    if (const Expr *E = dyn_cast<Expr>(S)) {
      const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(E->IgnoreParens());
      if (BinOp && BinOp->isLogicalOp()) {
        assert(LogOpStack.back() == BinOp);
        LogOpStack.pop_back();

        /// At the top of logical operator nest:
        if (LogOpStack.empty()) {
          /// Was the "split-nested" logical operator case encountered?
          if (SplitNestedLogicalOp) {
            unsigned DiagID = Diag.getCustomDiagID(
                DiagnosticsEngine::Warning,
                "unsupported MC/DC boolean expression; "
                "contains an operation with a nested boolean expression. "
                "Expression will not be covered");
            Diag.Report(S->getBeginLoc(), DiagID);
            return true;
          }

          /// Was the maximum number of conditions encountered?
          if (NumCond > MCDCMaxCond) {
            unsigned DiagID = Diag.getCustomDiagID(
                DiagnosticsEngine::Warning,
                "unsupported MC/DC boolean expression; "
                "number of conditions (%0) exceeds max (%1). "
                "Expression will not be covered");
            Diag.Report(S->getBeginLoc(), DiagID) << NumCond << MCDCMaxCond;
            return true;
          }

          // Otherwise, allocate the number of bytes required for the bitmap
          // based on the number of conditions. Must be at least 1-byte long.
          // The bitmap needs one bit per possible test vector (2^NumCond).
          MCDCBitmapMap[BinOp] = NextMCDCBitmapIdx;
          unsigned SizeInBits = std::max<unsigned>(1L << NumCond, CHAR_BIT);
          NextMCDCBitmapIdx += SizeInBits / CHAR_BIT;
        }
        return true;
      }
    }

    if (!LogOpStack.empty())
      NonLogOpStack.pop_back();

    return true;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility. In addition, in
  /// order to use MC/DC, count the number of total LHS and RHS conditions.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (S->isLogicalOp()) {
      if (CodeGenFunction::isInstrumentedCondition(S->getLHS()))
        NumCond++;

      if (CodeGenFunction::isInstrumentedCondition(S->getRHS())) {
        if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
          CounterMap[S->getRHS()] = NextCounter++;

        NumCond++;
      }
    }
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    // For newer hash versions, re-classify with the richer V2/V3 type set;
    // counters above were still assigned per the V1 classification.
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  /// Returns PGOHash::None for statements that do not contribute to the hash.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      // V2+ also hashes comparison operators.
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // V2+ additionally hashes statements affecting control flow that V1
    // ignored.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
487
/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount = 0;
    uint64_t ContinueCount = 0;
    BreakContinue() = default;
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  // If a prior statement requested it (e.g. the exit of a loop), record the
  // current count on \p S and clear the request.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  // Default: propagate the count through all children in order.
  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Control never falls through a return; zero the count and record it on
  // whatever statement follows.
  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  // Like return: a throw terminates the fallthrough path.
  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // The exit count is the times the condition was false, plus breaks.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    // No edge falls through directly from the switch header to the body;
    // each case adds its own count.
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);

    if (S->isConsteval()) {
      // Only one arm of a consteval if survives to runtime; visit it alone.
      const Stmt *Stm = S->isNegatedConsteval() ? S->getThen() : S->getElse();
      if (Stm)
        Visit(Stm);
      return;
    }

    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
875} // end anonymous namespace
876
// Fold one region type into the running function hash. Types are packed
// NumBitsPerType bits at a time into `Working`; once a full word has
// accumulated it is flushed into the incremental MD5. The exact bit layout
// here is part of the profile format: changing it would invalidate every
// previously recorded profile hash, so this must remain bit-exact.
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    // Canonicalize to little-endian before hashing so the MD5 result is the
    // same regardless of host endianness.
    uint64_t Swapped =
        endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
    MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}
895
// Produce the final 64-bit function hash. Small functions (everything fit in
// one word) use the packed word directly; larger ones flush the remainder and
// take the low 64 bits of the MD5. Like combine(), this is format-defining
// and must stay bit-exact across releases for a given hash version.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    if (HashVersion < PGO_HASH_V3) {
      // Only the low byte of Working is hashed here — intentionally preserved
      // so old profiles keep matching under hash versions 1 and 2.
      MD5.update({(uint8_t)Working});
    } else {
      using namespace llvm::support;
      // V3+: hash the full word, canonicalized to little-endian.
      uint64_t Swapped =
          endian::byte_swap<uint64_t, llvm::endianness::little>(Working);
      MD5.update(llvm::ArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}
923
// Decide whether this function should be instrumented / profiled at all, and
// if so, number its regions, emit coverage mapping, and load any existing
// profile counts. The sequence of early returns below is a filter chain:
// each one excludes a category of functions that must not be counted.
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  // Skip CUDA/HIP kernel launch stub functions.
  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
      D->hasAttr<CUDAGlobalAttr>())
    return;

  // Nothing to do unless we are either inserting instrumentation or have a
  // profile to read counts from.
  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  // Note: this happens before the attribute checks below so stale coverage
  // data is cleared even for functions we then decline to profile.
  CGM.ClearUnusedCoverageMapping(D);
  if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
    return;
  if (Fn->hasFnAttribute(llvm::Attribute::SkipProfile))
    return;

  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}
970
971void CodeGenPGO::mapRegionCounters(const Decl *D) {
972  // Use the latest hash version when inserting instrumentation, but use the
973  // version in the indexed profile if we're reading PGO data.
974  PGOHashVersion HashVersion = PGO_HASH_LATEST;
975  uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
976  if (auto *PGOReader = CGM.getPGOReader()) {
977    HashVersion = getPGOHashVersion(PGOReader, CGM);
978    ProfileVersion = PGOReader->getVersion();
979  }
980
981  // If MC/DC is enabled, set the MaxConditions to a preset value. Otherwise,
982  // set it to zero. This value impacts the number of conditions accepted in a
983  // given boolean expression, which impacts the size of the bitmap used to
984  // track test vector execution for that boolean expression.  Because the
985  // bitmap scales exponentially (2^n) based on the number of conditions seen,
986  // the maximum value is hard-coded at 6 conditions, which is more than enough
987  // for most embedded applications. Setting a maximum value prevents the
988  // bitmap footprint from growing too large without the user's knowledge. In
989  // the future, this value could be adjusted with a command-line option.
990  unsigned MCDCMaxConditions = (CGM.getCodeGenOpts().MCDCCoverage) ? 6 : 0;
991
992  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
993  RegionMCDCBitmapMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
994  MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap,
995                           *RegionMCDCBitmapMap, MCDCMaxConditions,
996                           CGM.getDiags());
997  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
998    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
999  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1000    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
1001  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1002    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
1003  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1004    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
1005  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
1006  NumRegionCounters = Walker.NextCounter;
1007  MCDCBitmapBytes = Walker.NextMCDCBitmapIdx;
1008  FunctionHash = Walker.Hash.finalize();
1009}
1010
1011bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
1012  if (!D->getBody())
1013    return true;
1014
1015  // Skip host-only functions in the CUDA device compilation and device-only
1016  // functions in the host compilation. Just roughly filter them out based on
1017  // the function attributes. If there are effectively host-only or device-only
1018  // ones, their coverage mapping may still be generated.
1019  if (CGM.getLangOpts().CUDA &&
1020      ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
1021        !D->hasAttr<CUDAGlobalAttr>()) ||
1022       (!CGM.getLangOpts().CUDAIsDevice &&
1023        (D->hasAttr<CUDAGlobalAttr>() ||
1024         (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
1025    return true;
1026
1027  // Don't map the functions in system headers.
1028  const auto &SM = CGM.getContext().getSourceManager();
1029  auto Loc = D->getBody()->getBeginLoc();
1030  return !SystemHeadersCoverage && SM.isInSystemHeader(Loc);
1031}
1032
1033void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
1034  if (skipRegionMappingForDecl(D))
1035    return;
1036
1037  std::string CoverageMapping;
1038  llvm::raw_string_ostream OS(CoverageMapping);
1039  RegionCondIDMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
1040  CoverageMappingGen MappingGen(
1041      *CGM.getCoverageMapping(), CGM.getContext().getSourceManager(),
1042      CGM.getLangOpts(), RegionCounterMap.get(), RegionMCDCBitmapMap.get(),
1043      RegionCondIDMap.get());
1044  MappingGen.emitCounterMapping(D, OS);
1045  OS.flush();
1046
1047  if (CoverageMapping.empty())
1048    return;
1049
1050  CGM.getCoverageMapping()->addFunctionMappingRecord(
1051      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
1052}
1053
1054void
1055CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
1056                                    llvm::GlobalValue::LinkageTypes Linkage) {
1057  if (skipRegionMappingForDecl(D))
1058    return;
1059
1060  std::string CoverageMapping;
1061  llvm::raw_string_ostream OS(CoverageMapping);
1062  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
1063                                CGM.getContext().getSourceManager(),
1064                                CGM.getLangOpts());
1065  MappingGen.emitEmptyMapping(D, OS);
1066  OS.flush();
1067
1068  if (CoverageMapping.empty())
1069    return;
1070
1071  setFuncName(Name, Linkage);
1072  CGM.getCoverageMapping()->addFunctionMappingRecord(
1073      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
1074}
1075
1076void CodeGenPGO::computeRegionCounts(const Decl *D) {
1077  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
1078  ComputeRegionCounts Walker(*StmtCountMap, *this);
1079  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
1080    Walker.VisitFunctionDecl(FD);
1081  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
1082    Walker.VisitObjCMethodDecl(MD);
1083  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
1084    Walker.VisitBlockDecl(BD);
1085  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
1086    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
1087}
1088
1089void
1090CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
1091                                    llvm::Function *Fn) {
1092  if (!haveRegionCounts())
1093    return;
1094
1095  uint64_t FunctionCount = getRegionCount(nullptr);
1096  Fn->setEntryCount(FunctionCount);
1097}
1098
1099void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
1100                                      llvm::Value *StepV) {
1101  if (!RegionCounterMap || !Builder.GetInsertBlock())
1102    return;
1103
1104  unsigned Counter = (*RegionCounterMap)[S];
1105
1106  llvm::Value *Args[] = {FuncNameVar,
1107                         Builder.getInt64(FunctionHash),
1108                         Builder.getInt32(NumRegionCounters),
1109                         Builder.getInt32(Counter), StepV};
1110  if (!StepV)
1111    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
1112                       ArrayRef(Args, 4));
1113  else
1114    Builder.CreateCall(
1115        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
1116        ArrayRef(Args));
1117}
1118
1119bool CodeGenPGO::canEmitMCDCCoverage(const CGBuilderTy &Builder) {
1120  return (CGM.getCodeGenOpts().hasProfileClangInstr() &&
1121          CGM.getCodeGenOpts().MCDCCoverage && Builder.GetInsertBlock());
1122}
1123
1124void CodeGenPGO::emitMCDCParameters(CGBuilderTy &Builder) {
1125  if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1126    return;
1127
1128  auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1129
1130  // Emit intrinsic representing MCDC bitmap parameters at function entry.
1131  // This is used by the instrumentation pass, but it isn't actually lowered to
1132  // anything.
1133  llvm::Value *Args[3] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1134                          Builder.getInt64(FunctionHash),
1135                          Builder.getInt32(MCDCBitmapBytes)};
1136  Builder.CreateCall(
1137      CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_parameters), Args);
1138}
1139
1140void CodeGenPGO::emitMCDCTestVectorBitmapUpdate(CGBuilderTy &Builder,
1141                                                const Expr *S,
1142                                                Address MCDCCondBitmapAddr) {
1143  if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1144    return;
1145
1146  S = S->IgnoreParens();
1147
1148  auto ExprMCDCBitmapMapIterator = RegionMCDCBitmapMap->find(S);
1149  if (ExprMCDCBitmapMapIterator == RegionMCDCBitmapMap->end())
1150    return;
1151
1152  // Extract the ID of the global bitmap associated with this expression.
1153  unsigned MCDCTestVectorBitmapID = ExprMCDCBitmapMapIterator->second;
1154  auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1155
1156  // Emit intrinsic responsible for updating the global bitmap corresponding to
1157  // a boolean expression. The index being set is based on the value loaded
1158  // from a pointer to a dedicated temporary value on the stack that is itself
1159  // updated via emitMCDCCondBitmapReset() and emitMCDCCondBitmapUpdate(). The
1160  // index represents an executed test vector.
1161  llvm::Value *Args[5] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1162                          Builder.getInt64(FunctionHash),
1163                          Builder.getInt32(MCDCBitmapBytes),
1164                          Builder.getInt32(MCDCTestVectorBitmapID),
1165                          MCDCCondBitmapAddr.getPointer()};
1166  Builder.CreateCall(
1167      CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_tvbitmap_update), Args);
1168}
1169
1170void CodeGenPGO::emitMCDCCondBitmapReset(CGBuilderTy &Builder, const Expr *S,
1171                                         Address MCDCCondBitmapAddr) {
1172  if (!canEmitMCDCCoverage(Builder) || !RegionMCDCBitmapMap)
1173    return;
1174
1175  S = S->IgnoreParens();
1176
1177  if (RegionMCDCBitmapMap->find(S) == RegionMCDCBitmapMap->end())
1178    return;
1179
1180  // Emit intrinsic that resets a dedicated temporary value on the stack to 0.
1181  Builder.CreateStore(Builder.getInt32(0), MCDCCondBitmapAddr);
1182}
1183
1184void CodeGenPGO::emitMCDCCondBitmapUpdate(CGBuilderTy &Builder, const Expr *S,
1185                                          Address MCDCCondBitmapAddr,
1186                                          llvm::Value *Val) {
1187  if (!canEmitMCDCCoverage(Builder) || !RegionCondIDMap)
1188    return;
1189
1190  // Even though, for simplicity, parentheses and unary logical-NOT operators
1191  // are considered part of their underlying condition for both MC/DC and
1192  // branch coverage, the condition IDs themselves are assigned and tracked
1193  // using the underlying condition itself.  This is done solely for
1194  // consistency since parentheses and logical-NOTs are ignored when checking
1195  // whether the condition is actually an instrumentable condition. This can
1196  // also make debugging a bit easier.
1197  S = CodeGenFunction::stripCond(S);
1198
1199  auto ExprMCDCConditionIDMapIterator = RegionCondIDMap->find(S);
1200  if (ExprMCDCConditionIDMapIterator == RegionCondIDMap->end())
1201    return;
1202
1203  // Extract the ID of the condition we are setting in the bitmap.
1204  unsigned CondID = ExprMCDCConditionIDMapIterator->second;
1205  assert(CondID > 0 && "Condition has no ID!");
1206
1207  auto *I8PtrTy = llvm::PointerType::getUnqual(CGM.getLLVMContext());
1208
1209  // Emit intrinsic that updates a dedicated temporary value on the stack after
1210  // a condition is evaluated. After the set of conditions has been updated,
1211  // the resulting value is used to update the boolean expression's bitmap.
1212  llvm::Value *Args[5] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
1213                          Builder.getInt64(FunctionHash),
1214                          Builder.getInt32(CondID - 1),
1215                          MCDCCondBitmapAddr.getPointer(), Val};
1216  Builder.CreateCall(
1217      CGM.getIntrinsic(llvm::Intrinsic::instrprof_mcdc_condbitmap_update),
1218      Args);
1219}
1220
1221void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
1222  if (CGM.getCodeGenOpts().hasProfileClangInstr())
1223    M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
1224                    uint32_t(EnableValueProfiling));
1225}
1226
// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
// The two paths are mutually exclusive per call, and each advances
// NumValueSites[ValueKind] so that instrumentation and annotation walk the
// value sites in the same order.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
    llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {

  if (!EnableValueProfiling)
    return;

  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  // A constant value is not worth profiling.
  if (isa<llvm::Constant>(ValuePtr))
    return;

  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    // Temporarily move the insertion point to just before the value site so
    // the intrinsic observes the value at the right program point; restored
    // below.
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    // NOTE: the post-increment of NumValueSites inside this initializer is
    // the per-kind site numbering — do not reorder.
    llvm::Value *Args[5] = {
        FuncNameVar,
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record the top most called three functions at each call site.
    // Profile metadata contains "VP" string identifying this metadata
    // as value profiling data, then a uint32_t value for the value profiling
    // kind, a uint64_t value for the total number of times the call is
    // executed, followed by the function hash and execution count (uint64_t)
    // pairs for each function.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    NumValueSites[ValueKind]++;
  }
}
1276
1277void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
1278                                  bool IsInMainFile) {
1279  CGM.getPGOStats().addVisited(IsInMainFile);
1280  RegionCounts.clear();
1281  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
1282      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1283  if (auto E = RecordExpected.takeError()) {
1284    auto IPE = std::get<0>(llvm::InstrProfError::take(std::move(E)));
1285    if (IPE == llvm::instrprof_error::unknown_function)
1286      CGM.getPGOStats().addMissing(IsInMainFile);
1287    else if (IPE == llvm::instrprof_error::hash_mismatch)
1288      CGM.getPGOStats().addMismatched(IsInMainFile);
1289    else if (IPE == llvm::instrprof_error::malformed)
1290      // TODO: Consider a more specific warning for this case.
1291      CGM.getPGOStats().addMismatched(IsInMainFile);
1292    return;
1293  }
1294  ProfRecord =
1295      std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1296  RegionCounts = ProfRecord->Counts;
1297}
1298
/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1306
/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Result = (Weight / Scale) + 1;
  assert(Result <= UINT32_MAX && "overflow 32-bits");
  return Result;
}
1322
1323llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1324                                                    uint64_t FalseCount) const {
1325  // Check for empty weights.
1326  if (!TrueCount && !FalseCount)
1327    return nullptr;
1328
1329  // Calculate how to scale down to 32-bits.
1330  uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1331
1332  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1333  return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1334                                      scaleBranchWeight(FalseCount, Scale));
1335}
1336
1337llvm::MDNode *
1338CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1339  // We need at least two elements to create meaningful weights.
1340  if (Weights.size() < 2)
1341    return nullptr;
1342
1343  // Check for empty weights.
1344  uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1345  if (MaxWeight == 0)
1346    return nullptr;
1347
1348  // Calculate how to scale down to 32-bits.
1349  uint64_t Scale = calculateWeightScale(MaxWeight);
1350
1351  SmallVector<uint32_t, 16> ScaledWeights;
1352  ScaledWeights.reserve(Weights.size());
1353  for (uint64_t W : Weights)
1354    ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1355
1356  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1357  return MDHelper.createBranchWeights(ScaledWeights);
1358}
1359
1360llvm::MDNode *
1361CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1362                                             uint64_t LoopCount) const {
1363  if (!PGO.haveRegionCounts())
1364    return nullptr;
1365  std::optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1366  if (!CondCount || *CondCount == 0)
1367    return nullptr;
1368  return createProfileWeights(LoopCount,
1369                              std::max(*CondCount, LoopCount) - LoopCount);
1370}
1371