1//===- MergeFunctions.cpp - Merge identical functions ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This pass looks for equivalent functions that are mergeable and folds them.
10//
11// Order relation is defined on set of functions. It was made through
12// special function comparison procedure that returns
13// 0 when functions are equal,
14// -1 when Left function is less than right function, and
15// 1 for opposite case. We need total-ordering, so we need to maintain
16// four properties on the functions set:
17// a <= a (reflexivity)
18// if a <= b and b <= a then a = b (antisymmetry)
19// if a <= b and b <= c then a <= c (transitivity).
20// for all a and b: a <= b or b <= a (totality).
21//
22// Comparison iterates through each instruction in each basic block.
23// Functions are kept on binary tree. For each new function F we perform
24// lookup in binary tree.
25// In practice it works the following way:
26// -- We define Function* container class with custom "operator<" (FunctionPtr).
27// -- "FunctionPtr" instances are stored in std::set collection, so every
28//    std::set::insert operation will give you result in log(N) time.
29//
30// As an optimization, a hash of the function structure is calculated first, and
31// two functions are only compared if they have the same hash. This hash is
32// cheap to compute, and has the property that if function F == G according to
33// the comparison function, then hash(F) == hash(G). This consistency property
34// is critical to ensuring all possible merging opportunities are exploited.
35// Collisions in the hash affect the speed of the pass but not the correctness
36// or determinism of the resulting transformation.
37//
38// When a match is found the functions are folded. If both functions are
39// overridable, we move the functionality into a new internal function and
40// leave two overridable thunks to it.
41//
42//===----------------------------------------------------------------------===//
43//
44// Future work:
45//
46// * virtual functions.
47//
48// Many functions have their address taken by the virtual function table for
49// the object they belong to. However, as long as it's only used for a lookup
50// and call, this is irrelevant, and we'd like to fold such functions.
51//
52// * be smarter about bitcasts.
53//
54// In order to fold functions, we will sometimes add either bitcast instructions
55// or bitcast constant expressions. Unfortunately, this can confound further
56// analysis since the two functions differ where one has a bitcast and the
57// other doesn't. We should learn to look through bitcasts.
58//
59// * Compare complex types with pointer types inside.
60// * Compare cross-reference cases.
61// * Compare complex expressions.
62//
63// All the three issues above could be described as ability to prove that
64// fA == fB == fC == fE == fF == fG in example below:
65//
66//  void fA() {
67//    fB();
68//  }
69//  void fB() {
70//    fA();
71//  }
72//
73//  void fE() {
74//    fF();
75//  }
76//  void fF() {
77//    fG();
78//  }
79//  void fG() {
80//    fE();
81//  }
82//
// The simplest cross-reference case (fA <--> fB) was implemented in previous
// versions of MergeFunctions, though it occurred in only two function pairs
// in the test-suite (which contains >50k functions).
// The ability to detect complex cross-referencing (e.g.: A->B->C->D->A),
// however, could cover many more cases.
88//
89//===----------------------------------------------------------------------===//
90
91#include "llvm/Transforms/IPO/MergeFunctions.h"
92#include "llvm/ADT/ArrayRef.h"
93#include "llvm/ADT/SmallVector.h"
94#include "llvm/ADT/Statistic.h"
95#include "llvm/IR/Argument.h"
96#include "llvm/IR/BasicBlock.h"
97#include "llvm/IR/Constant.h"
98#include "llvm/IR/Constants.h"
99#include "llvm/IR/DebugInfoMetadata.h"
100#include "llvm/IR/DebugLoc.h"
101#include "llvm/IR/DerivedTypes.h"
102#include "llvm/IR/Function.h"
103#include "llvm/IR/GlobalValue.h"
104#include "llvm/IR/IRBuilder.h"
105#include "llvm/IR/InstrTypes.h"
106#include "llvm/IR/Instruction.h"
107#include "llvm/IR/Instructions.h"
108#include "llvm/IR/IntrinsicInst.h"
109#include "llvm/IR/Module.h"
110#include "llvm/IR/StructuralHash.h"
111#include "llvm/IR/Type.h"
112#include "llvm/IR/Use.h"
113#include "llvm/IR/User.h"
114#include "llvm/IR/Value.h"
115#include "llvm/IR/ValueHandle.h"
116#include "llvm/Support/Casting.h"
117#include "llvm/Support/CommandLine.h"
118#include "llvm/Support/Debug.h"
119#include "llvm/Support/raw_ostream.h"
120#include "llvm/Transforms/IPO.h"
121#include "llvm/Transforms/Utils/FunctionComparator.h"
122#include "llvm/Transforms/Utils/ModuleUtils.h"
123#include <algorithm>
124#include <cassert>
125#include <iterator>
126#include <set>
127#include <utility>
128#include <vector>
129
130using namespace llvm;
131
132#define DEBUG_TYPE "mergefunc"
133
// Counters maintained by this pass.
// NumFunctionsMerged is bumped whenever a pair of functions is folded.
STATISTIC(NumFunctionsMerged, "Number of functions merged");
// NumThunksWritten is bumped once per call to writeThunk().
STATISTIC(NumThunksWritten, "Number of thunks generated");
// NumAliasesWritten is bumped once per call to writeAlias().
STATISTIC(NumAliasesWritten, "Number of aliases generated");
// NumDoubleWeak is bumped when both functions of a pair are interposable and
// a fresh function is created for them in mergeTwoFunctions().
STATISTIC(NumDoubleWeak, "Number of new functions created");

// Upper bound on how many worklist functions doFunctionalCheck() examines when
// verifying the ordering produced by the function comparator. The default of
// 0 turns the check off.
static cl::opt<unsigned> NumFunctionsForVerificationCheck(
    "mergefunc-verify",
    cl::desc("How many functions in a module could be used for "
             "MergeFunctions to pass a basic correctness check. "
             "'0' disables this check. Works only with '-debug' key."),
    cl::init(0), cl::Hidden);

// Under option -mergefunc-preserve-debug-info we:
// - Do not create a new function for a thunk.
// - Retain the debug info for a thunk's parameters (and associated
//   instructions for the debug info) from the entry block.
//   Note: -debug will display the algorithm at work.
// - Create debug-info for the call (to the shared implementation) made by
//   a thunk and its return value.
// - Erase the rest of the function, retaining the (minimally sized) entry
//   block to create a thunk.
// - Preserve a thunk's call site to point to the thunk even when both occur
//   within the same translation unit, to aid debuggability. Note that this
//   behaviour differs from the underlying -mergefunc implementation which
//   modifies the thunk's call site to point to the shared implementation
//   when both occur within the same translation unit.
static cl::opt<bool>
    MergeFunctionsPDI("mergefunc-preserve-debug-info", cl::Hidden,
                      cl::init(false),
                      cl::desc("Preserve debug info in thunk when mergefunc "
                               "transformations are made."));

// When set, canCreateAliasFor() may allow a merged function to be replaced by
// a GlobalAlias instead of a thunk. Off by default.
static cl::opt<bool>
    MergeFunctionsAliases("mergefunc-use-aliases", cl::Hidden,
                          cl::init(false),
                          cl::desc("Allow mergefunc to create aliases"));
170
171namespace {
172
173class FunctionNode {
174  mutable AssertingVH<Function> F;
175  IRHash Hash;
176
177public:
178  // Note the hash is recalculated potentially multiple times, but it is cheap.
179  FunctionNode(Function *F) : F(F), Hash(StructuralHash(*F)) {}
180
181  Function *getFunc() const { return F; }
182  IRHash getHash() const { return Hash; }
183
184  /// Replace the reference to the function F by the function G, assuming their
185  /// implementations are equal.
186  void replaceBy(Function *G) const {
187    F = G;
188  }
189};
190
191/// MergeFunctions finds functions which will generate identical machine code,
192/// by considering all pointer types to be equivalent. Once identified,
193/// MergeFunctions will fold them by replacing a call to one to a call to a
194/// bitcast of the other.
195class MergeFunctions {
196public:
197  MergeFunctions() : FnTree(FunctionNodeCmp(&GlobalNumbers)) {
198  }
199
200  bool runOnModule(Module &M);
201
202private:
203  // The function comparison operator is provided here so that FunctionNodes do
204  // not need to become larger with another pointer.
205  class FunctionNodeCmp {
206    GlobalNumberState* GlobalNumbers;
207
208  public:
209    FunctionNodeCmp(GlobalNumberState* GN) : GlobalNumbers(GN) {}
210
211    bool operator()(const FunctionNode &LHS, const FunctionNode &RHS) const {
212      // Order first by hashes, then full function comparison.
213      if (LHS.getHash() != RHS.getHash())
214        return LHS.getHash() < RHS.getHash();
215      FunctionComparator FCmp(LHS.getFunc(), RHS.getFunc(), GlobalNumbers);
216      return FCmp.compare() < 0;
217    }
218  };
219  using FnTreeType = std::set<FunctionNode, FunctionNodeCmp>;
220
221  GlobalNumberState GlobalNumbers;
222
223  /// A work queue of functions that may have been modified and should be
224  /// analyzed again.
225  std::vector<WeakTrackingVH> Deferred;
226
227  /// Set of values marked as used in llvm.used and llvm.compiler.used.
228  SmallPtrSet<GlobalValue *, 4> Used;
229
230#ifndef NDEBUG
231  /// Checks the rules of order relation introduced among functions set.
232  /// Returns true, if check has been passed, and false if failed.
233  bool doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist);
234#endif
235
236  /// Insert a ComparableFunction into the FnTree, or merge it away if it's
237  /// equal to one that's already present.
238  bool insert(Function *NewFunction);
239
240  /// Remove a Function from the FnTree and queue it up for a second sweep of
241  /// analysis.
242  void remove(Function *F);
243
244  /// Find the functions that use this Value and remove them from FnTree and
245  /// queue the functions.
246  void removeUsers(Value *V);
247
248  /// Replace all direct calls of Old with calls of New. Will bitcast New if
249  /// necessary to make types match.
250  void replaceDirectCallers(Function *Old, Function *New);
251
252  /// Merge two equivalent functions. Upon completion, G may be deleted, or may
253  /// be converted into a thunk. In either case, it should never be visited
254  /// again.
255  void mergeTwoFunctions(Function *F, Function *G);
256
257  /// Fill PDIUnrelatedWL with instructions from the entry block that are
258  /// unrelated to parameter related debug info.
259  void filterInstsUnrelatedToPDI(BasicBlock *GEntryBlock,
260                                 std::vector<Instruction *> &PDIUnrelatedWL);
261
262  /// Erase the rest of the CFG (i.e. barring the entry block).
263  void eraseTail(Function *G);
264
265  /// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
266  /// parameter debug info, from the entry block.
267  void eraseInstsUnrelatedToPDI(std::vector<Instruction *> &PDIUnrelatedWL);
268
269  /// Replace G with a simple tail call to bitcast(F). Also (unless
270  /// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F),
271  /// delete G.
272  void writeThunk(Function *F, Function *G);
273
274  // Replace G with an alias to F (deleting function G)
275  void writeAlias(Function *F, Function *G);
276
277  // Replace G with an alias to F if possible, or a thunk to F if possible.
278  // Returns false if neither is the case.
279  bool writeThunkOrAlias(Function *F, Function *G);
280
281  /// Replace function F with function G in the function tree.
282  void replaceFunctionInTree(const FunctionNode &FN, Function *G);
283
284  /// The set of all distinct functions. Use the insert() and remove() methods
285  /// to modify it. The map allows efficient lookup and deferring of Functions.
286  FnTreeType FnTree;
287
288  // Map functions to the iterators of the FunctionNode which contains them
289  // in the FnTree. This must be updated carefully whenever the FnTree is
290  // modified, i.e. in insert(), remove(), and replaceFunctionInTree(), to avoid
291  // dangling iterators into FnTree. The invariant that preserves this is that
292  // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
293  DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree;
294};
295} // end anonymous namespace
296
297PreservedAnalyses MergeFunctionsPass::run(Module &M,
298                                          ModuleAnalysisManager &AM) {
299  MergeFunctions MF;
300  if (!MF.runOnModule(M))
301    return PreservedAnalyses::all();
302  return PreservedAnalyses::none();
303}
304
305#ifndef NDEBUG
306bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) {
307  if (const unsigned Max = NumFunctionsForVerificationCheck) {
308    unsigned TripleNumber = 0;
309    bool Valid = true;
310
311    dbgs() << "MERGEFUNC-VERIFY: Started for first " << Max << " functions.\n";
312
313    unsigned i = 0;
314    for (std::vector<WeakTrackingVH>::iterator I = Worklist.begin(),
315                                               E = Worklist.end();
316         I != E && i < Max; ++I, ++i) {
317      unsigned j = i;
318      for (std::vector<WeakTrackingVH>::iterator J = I; J != E && j < Max;
319           ++J, ++j) {
320        Function *F1 = cast<Function>(*I);
321        Function *F2 = cast<Function>(*J);
322        int Res1 = FunctionComparator(F1, F2, &GlobalNumbers).compare();
323        int Res2 = FunctionComparator(F2, F1, &GlobalNumbers).compare();
324
325        // If F1 <= F2, then F2 >= F1, otherwise report failure.
326        if (Res1 != -Res2) {
327          dbgs() << "MERGEFUNC-VERIFY: Non-symmetric; triple: " << TripleNumber
328                 << "\n";
329          dbgs() << *F1 << '\n' << *F2 << '\n';
330          Valid = false;
331        }
332
333        if (Res1 == 0)
334          continue;
335
336        unsigned k = j;
337        for (std::vector<WeakTrackingVH>::iterator K = J; K != E && k < Max;
338             ++k, ++K, ++TripleNumber) {
339          if (K == J)
340            continue;
341
342          Function *F3 = cast<Function>(*K);
343          int Res3 = FunctionComparator(F1, F3, &GlobalNumbers).compare();
344          int Res4 = FunctionComparator(F2, F3, &GlobalNumbers).compare();
345
346          bool Transitive = true;
347
348          if (Res1 != 0 && Res1 == Res4) {
349            // F1 > F2, F2 > F3 => F1 > F3
350            Transitive = Res3 == Res1;
351          } else if (Res3 != 0 && Res3 == -Res4) {
352            // F1 > F3, F3 > F2 => F1 > F2
353            Transitive = Res3 == Res1;
354          } else if (Res4 != 0 && -Res3 == Res4) {
355            // F2 > F3, F3 > F1 => F2 > F1
356            Transitive = Res4 == -Res1;
357          }
358
359          if (!Transitive) {
360            dbgs() << "MERGEFUNC-VERIFY: Non-transitive; triple: "
361                   << TripleNumber << "\n";
362            dbgs() << "Res1, Res3, Res4: " << Res1 << ", " << Res3 << ", "
363                   << Res4 << "\n";
364            dbgs() << *F1 << '\n' << *F2 << '\n' << *F3 << '\n';
365            Valid = false;
366          }
367        }
368      }
369    }
370
371    dbgs() << "MERGEFUNC-VERIFY: " << (Valid ? "Passed." : "Failed.") << "\n";
372    return Valid;
373  }
374  return true;
375}
376#endif
377
378/// Check whether \p F has an intrinsic which references
379/// distinct metadata as an operand. The most common
380/// instance of this would be CFI checks for function-local types.
381static bool hasDistinctMetadataIntrinsic(const Function &F) {
382  for (const BasicBlock &BB : F) {
383    for (const Instruction &I : BB.instructionsWithoutDebug()) {
384      if (!isa<IntrinsicInst>(&I))
385        continue;
386
387      for (Value *Op : I.operands()) {
388        auto *MDL = dyn_cast<MetadataAsValue>(Op);
389        if (!MDL)
390          continue;
391        if (MDNode *N = dyn_cast<MDNode>(MDL->getMetadata()))
392          if (N->isDistinct())
393            return true;
394      }
395    }
396  }
397  return false;
398}
399
400/// Check whether \p F is eligible for function merging.
401static bool isEligibleForMerging(Function &F) {
402  return !F.isDeclaration() && !F.hasAvailableExternallyLinkage() &&
403         !hasDistinctMetadataIntrinsic(F);
404}
405
406bool MergeFunctions::runOnModule(Module &M) {
407  bool Changed = false;
408
409  SmallVector<GlobalValue *, 4> UsedV;
410  collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/false);
411  collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/true);
412  Used.insert(UsedV.begin(), UsedV.end());
413
414  // All functions in the module, ordered by hash. Functions with a unique
415  // hash value are easily eliminated.
416  std::vector<std::pair<IRHash, Function *>> HashedFuncs;
417  for (Function &Func : M) {
418    if (isEligibleForMerging(Func)) {
419      HashedFuncs.push_back({StructuralHash(Func), &Func});
420    }
421  }
422
423  llvm::stable_sort(HashedFuncs, less_first());
424
425  auto S = HashedFuncs.begin();
426  for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) {
427    // If the hash value matches the previous value or the next one, we must
428    // consider merging it. Otherwise it is dropped and never considered again.
429    if ((I != S && std::prev(I)->first == I->first) ||
430        (std::next(I) != IE && std::next(I)->first == I->first) ) {
431      Deferred.push_back(WeakTrackingVH(I->second));
432    }
433  }
434
435  do {
436    std::vector<WeakTrackingVH> Worklist;
437    Deferred.swap(Worklist);
438
439    LLVM_DEBUG(doFunctionalCheck(Worklist));
440
441    LLVM_DEBUG(dbgs() << "size of module: " << M.size() << '\n');
442    LLVM_DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n');
443
444    // Insert functions and merge them.
445    for (WeakTrackingVH &I : Worklist) {
446      if (!I)
447        continue;
448      Function *F = cast<Function>(I);
449      if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage()) {
450        Changed |= insert(F);
451      }
452    }
453    LLVM_DEBUG(dbgs() << "size of FnTree: " << FnTree.size() << '\n');
454  } while (!Deferred.empty());
455
456  FnTree.clear();
457  FNodesInTree.clear();
458  GlobalNumbers.clear();
459  Used.clear();
460
461  return Changed;
462}
463
464// Replace direct callers of Old with New.
465void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
466  for (Use &U : llvm::make_early_inc_range(Old->uses())) {
467    CallBase *CB = dyn_cast<CallBase>(U.getUser());
468    if (CB && CB->isCallee(&U)) {
469      // Do not copy attributes from the called function to the call-site.
470      // Function comparison ensures that the attributes are the same up to
471      // type congruences in byval(), in which case we need to keep the byval
472      // type of the call-site, not the callee function.
473      remove(CB->getFunction());
474      U.set(New);
475    }
476  }
477}
478
479// Helper for writeThunk,
480// Selects proper bitcast operation,
481// but a bit simpler then CastInst::getCastOpcode.
482static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
483  Type *SrcTy = V->getType();
484  if (SrcTy->isStructTy()) {
485    assert(DestTy->isStructTy());
486    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
487    Value *Result = PoisonValue::get(DestTy);
488    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
489      Value *Element =
490          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
491                     DestTy->getStructElementType(I));
492
493      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
494    }
495    return Result;
496  }
497  assert(!DestTy->isStructTy());
498  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
499    return Builder.CreateIntToPtr(V, DestTy);
500  else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
501    return Builder.CreatePtrToInt(V, DestTy);
502  else
503    return Builder.CreateBitCast(V, DestTy);
504}
505
506// Erase the instructions in PDIUnrelatedWL as they are unrelated to the
507// parameter debug info, from the entry block.
508void MergeFunctions::eraseInstsUnrelatedToPDI(
509    std::vector<Instruction *> &PDIUnrelatedWL) {
510  LLVM_DEBUG(
511      dbgs() << " Erasing instructions (in reverse order of appearance in "
512                "entry block) unrelated to parameter debug info from entry "
513                "block: {\n");
514  while (!PDIUnrelatedWL.empty()) {
515    Instruction *I = PDIUnrelatedWL.back();
516    LLVM_DEBUG(dbgs() << "  Deleting Instruction: ");
517    LLVM_DEBUG(I->print(dbgs()));
518    LLVM_DEBUG(dbgs() << "\n");
519    I->eraseFromParent();
520    PDIUnrelatedWL.pop_back();
521  }
522  LLVM_DEBUG(dbgs() << " } // Done erasing instructions unrelated to parameter "
523                       "debug info from entry block. \n");
524}
525
526// Reduce G to its entry block.
527void MergeFunctions::eraseTail(Function *G) {
528  std::vector<BasicBlock *> WorklistBB;
529  for (BasicBlock &BB : drop_begin(*G)) {
530    BB.dropAllReferences();
531    WorklistBB.push_back(&BB);
532  }
533  while (!WorklistBB.empty()) {
534    BasicBlock *BB = WorklistBB.back();
535    BB->eraseFromParent();
536    WorklistBB.pop_back();
537  }
538}
539
540// We are interested in the following instructions from the entry block as being
541// related to parameter debug info:
542// - @llvm.dbg.declare
543// - stores from the incoming parameters to locations on the stack-frame
544// - allocas that create these locations on the stack-frame
545// - @llvm.dbg.value
546// - the entry block's terminator
547// The rest are unrelated to debug info for the parameters; fill up
548// PDIUnrelatedWL with such instructions.
549void MergeFunctions::filterInstsUnrelatedToPDI(
550    BasicBlock *GEntryBlock, std::vector<Instruction *> &PDIUnrelatedWL) {
551  std::set<Instruction *> PDIRelated;
552  for (BasicBlock::iterator BI = GEntryBlock->begin(), BIE = GEntryBlock->end();
553       BI != BIE; ++BI) {
554    if (auto *DVI = dyn_cast<DbgValueInst>(&*BI)) {
555      LLVM_DEBUG(dbgs() << " Deciding: ");
556      LLVM_DEBUG(BI->print(dbgs()));
557      LLVM_DEBUG(dbgs() << "\n");
558      DILocalVariable *DILocVar = DVI->getVariable();
559      if (DILocVar->isParameter()) {
560        LLVM_DEBUG(dbgs() << "  Include (parameter): ");
561        LLVM_DEBUG(BI->print(dbgs()));
562        LLVM_DEBUG(dbgs() << "\n");
563        PDIRelated.insert(&*BI);
564      } else {
565        LLVM_DEBUG(dbgs() << "  Delete (!parameter): ");
566        LLVM_DEBUG(BI->print(dbgs()));
567        LLVM_DEBUG(dbgs() << "\n");
568      }
569    } else if (auto *DDI = dyn_cast<DbgDeclareInst>(&*BI)) {
570      LLVM_DEBUG(dbgs() << " Deciding: ");
571      LLVM_DEBUG(BI->print(dbgs()));
572      LLVM_DEBUG(dbgs() << "\n");
573      DILocalVariable *DILocVar = DDI->getVariable();
574      if (DILocVar->isParameter()) {
575        LLVM_DEBUG(dbgs() << "  Parameter: ");
576        LLVM_DEBUG(DILocVar->print(dbgs()));
577        AllocaInst *AI = dyn_cast_or_null<AllocaInst>(DDI->getAddress());
578        if (AI) {
579          LLVM_DEBUG(dbgs() << "  Processing alloca users: ");
580          LLVM_DEBUG(dbgs() << "\n");
581          for (User *U : AI->users()) {
582            if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
583              if (Value *Arg = SI->getValueOperand()) {
584                if (isa<Argument>(Arg)) {
585                  LLVM_DEBUG(dbgs() << "  Include: ");
586                  LLVM_DEBUG(AI->print(dbgs()));
587                  LLVM_DEBUG(dbgs() << "\n");
588                  PDIRelated.insert(AI);
589                  LLVM_DEBUG(dbgs() << "   Include (parameter): ");
590                  LLVM_DEBUG(SI->print(dbgs()));
591                  LLVM_DEBUG(dbgs() << "\n");
592                  PDIRelated.insert(SI);
593                  LLVM_DEBUG(dbgs() << "  Include: ");
594                  LLVM_DEBUG(BI->print(dbgs()));
595                  LLVM_DEBUG(dbgs() << "\n");
596                  PDIRelated.insert(&*BI);
597                } else {
598                  LLVM_DEBUG(dbgs() << "   Delete (!parameter): ");
599                  LLVM_DEBUG(SI->print(dbgs()));
600                  LLVM_DEBUG(dbgs() << "\n");
601                }
602              }
603            } else {
604              LLVM_DEBUG(dbgs() << "   Defer: ");
605              LLVM_DEBUG(U->print(dbgs()));
606              LLVM_DEBUG(dbgs() << "\n");
607            }
608          }
609        } else {
610          LLVM_DEBUG(dbgs() << "  Delete (alloca NULL): ");
611          LLVM_DEBUG(BI->print(dbgs()));
612          LLVM_DEBUG(dbgs() << "\n");
613        }
614      } else {
615        LLVM_DEBUG(dbgs() << "  Delete (!parameter): ");
616        LLVM_DEBUG(BI->print(dbgs()));
617        LLVM_DEBUG(dbgs() << "\n");
618      }
619    } else if (BI->isTerminator() && &*BI == GEntryBlock->getTerminator()) {
620      LLVM_DEBUG(dbgs() << " Will Include Terminator: ");
621      LLVM_DEBUG(BI->print(dbgs()));
622      LLVM_DEBUG(dbgs() << "\n");
623      PDIRelated.insert(&*BI);
624    } else {
625      LLVM_DEBUG(dbgs() << " Defer: ");
626      LLVM_DEBUG(BI->print(dbgs()));
627      LLVM_DEBUG(dbgs() << "\n");
628    }
629  }
630  LLVM_DEBUG(
631      dbgs()
632      << " Report parameter debug info related/related instructions: {\n");
633  for (Instruction &I : *GEntryBlock) {
634    if (PDIRelated.find(&I) == PDIRelated.end()) {
635      LLVM_DEBUG(dbgs() << "  !PDIRelated: ");
636      LLVM_DEBUG(I.print(dbgs()));
637      LLVM_DEBUG(dbgs() << "\n");
638      PDIUnrelatedWL.push_back(&I);
639    } else {
640      LLVM_DEBUG(dbgs() << "   PDIRelated: ");
641      LLVM_DEBUG(I.print(dbgs()));
642      LLVM_DEBUG(dbgs() << "\n");
643    }
644  }
645  LLVM_DEBUG(dbgs() << " }\n");
646}
647
648/// Whether this function may be replaced by a forwarding thunk.
649static bool canCreateThunkFor(Function *F) {
650  if (F->isVarArg())
651    return false;
652
653  // Don't merge tiny functions using a thunk, since it can just end up
654  // making the function larger.
655  if (F->size() == 1) {
656    if (F->front().sizeWithoutDebug() < 2) {
657      LLVM_DEBUG(dbgs() << "canCreateThunkFor: " << F->getName()
658                        << " is too small to bother creating a thunk for\n");
659      return false;
660    }
661  }
662  return true;
663}
664
665/// Copy metadata from one function to another.
666static void copyMetadataIfPresent(Function *From, Function *To, StringRef Key) {
667  if (MDNode *MD = From->getMetadata(Key)) {
668    To->setMetadata(Key, MD);
669  }
670}
671
672// Replace G with a simple tail call to bitcast(F). Also (unless
673// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F),
674// delete G. Under MergeFunctionsPDI, we use G itself for creating
675// the thunk as we preserve the debug info (and associated instructions)
676// from G's entry block pertaining to G's incoming arguments which are
677// passed on as corresponding arguments in the call that G makes to F.
678// For better debugability, under MergeFunctionsPDI, we do not modify G's
679// call sites to point to F even when within the same translation unit.
680void MergeFunctions::writeThunk(Function *F, Function *G) {
681  BasicBlock *GEntryBlock = nullptr;
682  std::vector<Instruction *> PDIUnrelatedWL;
683  BasicBlock *BB = nullptr;
684  Function *NewG = nullptr;
685  if (MergeFunctionsPDI) {
686    LLVM_DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) Do not create a new "
687                         "function as thunk; retain original: "
688                      << G->getName() << "()\n");
689    GEntryBlock = &G->getEntryBlock();
690    LLVM_DEBUG(
691        dbgs() << "writeThunk: (MergeFunctionsPDI) filter parameter related "
692                  "debug info for "
693               << G->getName() << "() {\n");
694    filterInstsUnrelatedToPDI(GEntryBlock, PDIUnrelatedWL);
695    GEntryBlock->getTerminator()->eraseFromParent();
696    BB = GEntryBlock;
697  } else {
698    NewG = Function::Create(G->getFunctionType(), G->getLinkage(),
699                            G->getAddressSpace(), "", G->getParent());
700    NewG->setComdat(G->getComdat());
701    BB = BasicBlock::Create(F->getContext(), "", NewG);
702  }
703
704  IRBuilder<> Builder(BB);
705  Function *H = MergeFunctionsPDI ? G : NewG;
706  SmallVector<Value *, 16> Args;
707  unsigned i = 0;
708  FunctionType *FFTy = F->getFunctionType();
709  for (Argument &AI : H->args()) {
710    Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
711    ++i;
712  }
713
714  CallInst *CI = Builder.CreateCall(F, Args);
715  ReturnInst *RI = nullptr;
716  bool isSwiftTailCall = F->getCallingConv() == CallingConv::SwiftTail &&
717                         G->getCallingConv() == CallingConv::SwiftTail;
718  CI->setTailCallKind(isSwiftTailCall ? llvm::CallInst::TCK_MustTail
719                                      : llvm::CallInst::TCK_Tail);
720  CI->setCallingConv(F->getCallingConv());
721  CI->setAttributes(F->getAttributes());
722  if (H->getReturnType()->isVoidTy()) {
723    RI = Builder.CreateRetVoid();
724  } else {
725    RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));
726  }
727
728  if (MergeFunctionsPDI) {
729    DISubprogram *DIS = G->getSubprogram();
730    if (DIS) {
731      DebugLoc CIDbgLoc =
732          DILocation::get(DIS->getContext(), DIS->getScopeLine(), 0, DIS);
733      DebugLoc RIDbgLoc =
734          DILocation::get(DIS->getContext(), DIS->getScopeLine(), 0, DIS);
735      CI->setDebugLoc(CIDbgLoc);
736      RI->setDebugLoc(RIDbgLoc);
737    } else {
738      LLVM_DEBUG(
739          dbgs() << "writeThunk: (MergeFunctionsPDI) No DISubprogram for "
740                 << G->getName() << "()\n");
741    }
742    eraseTail(G);
743    eraseInstsUnrelatedToPDI(PDIUnrelatedWL);
744    LLVM_DEBUG(
745        dbgs() << "} // End of parameter related debug info filtering for: "
746               << G->getName() << "()\n");
747  } else {
748    NewG->copyAttributesFrom(G);
749    NewG->takeName(G);
750    // Ensure CFI type metadata is propagated to the new function.
751    copyMetadataIfPresent(G, NewG, "type");
752    copyMetadataIfPresent(G, NewG, "kcfi_type");
753    removeUsers(G);
754    G->replaceAllUsesWith(NewG);
755    G->eraseFromParent();
756  }
757
758  LLVM_DEBUG(dbgs() << "writeThunk: " << H->getName() << '\n');
759  ++NumThunksWritten;
760}
761
762// Whether this function may be replaced by an alias
763static bool canCreateAliasFor(Function *F) {
764  if (!MergeFunctionsAliases || !F->hasGlobalUnnamedAddr())
765    return false;
766
767  // We should only see linkages supported by aliases here
768  assert(F->hasLocalLinkage() || F->hasExternalLinkage()
769      || F->hasWeakLinkage() || F->hasLinkOnceLinkage());
770  return true;
771}
772
773// Replace G with an alias to F (deleting function G)
774void MergeFunctions::writeAlias(Function *F, Function *G) {
775  PointerType *PtrType = G->getType();
776  auto *GA = GlobalAlias::create(G->getValueType(), PtrType->getAddressSpace(),
777                                 G->getLinkage(), "", F, G->getParent());
778
779  const MaybeAlign FAlign = F->getAlign();
780  const MaybeAlign GAlign = G->getAlign();
781  if (FAlign || GAlign)
782    F->setAlignment(std::max(FAlign.valueOrOne(), GAlign.valueOrOne()));
783  else
784    F->setAlignment(std::nullopt);
785  GA->takeName(G);
786  GA->setVisibility(G->getVisibility());
787  GA->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
788
789  removeUsers(G);
790  G->replaceAllUsesWith(GA);
791  G->eraseFromParent();
792
793  LLVM_DEBUG(dbgs() << "writeAlias: " << GA->getName() << '\n');
794  ++NumAliasesWritten;
795}
796
797// Replace G with an alias to F if possible, or a thunk to F if
798// profitable. Returns false if neither is the case.
799bool MergeFunctions::writeThunkOrAlias(Function *F, Function *G) {
800  if (canCreateAliasFor(G)) {
801    writeAlias(F, G);
802    return true;
803  }
804  if (canCreateThunkFor(F)) {
805    writeThunk(F, G);
806    return true;
807  }
808  return false;
809}
810
811// Merge two equivalent functions. Upon completion, Function G is deleted.
812void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
813  if (F->isInterposable()) {
814    assert(G->isInterposable());
815
816    // Both writeThunkOrAlias() calls below must succeed, either because we can
817    // create aliases for G and NewF, or because a thunk for F is profitable.
818    // F here has the same signature as NewF below, so that's what we check.
819    if (!canCreateThunkFor(F) &&
820        (!canCreateAliasFor(F) || !canCreateAliasFor(G)))
821      return;
822
823    // Make them both thunks to the same internal function.
824    Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(),
825                                      F->getAddressSpace(), "", F->getParent());
826    NewF->copyAttributesFrom(F);
827    NewF->takeName(F);
828    // Ensure CFI type metadata is propagated to the new function.
829    copyMetadataIfPresent(F, NewF, "type");
830    copyMetadataIfPresent(F, NewF, "kcfi_type");
831    removeUsers(F);
832    F->replaceAllUsesWith(NewF);
833
834    // We collect alignment before writeThunkOrAlias that overwrites NewF and
835    // G's content.
836    const MaybeAlign NewFAlign = NewF->getAlign();
837    const MaybeAlign GAlign = G->getAlign();
838
839    writeThunkOrAlias(F, G);
840    writeThunkOrAlias(F, NewF);
841
842    if (NewFAlign || GAlign)
843      F->setAlignment(std::max(NewFAlign.valueOrOne(), GAlign.valueOrOne()));
844    else
845      F->setAlignment(std::nullopt);
846    F->setLinkage(GlobalValue::PrivateLinkage);
847    ++NumDoubleWeak;
848    ++NumFunctionsMerged;
849  } else {
850    // For better debugability, under MergeFunctionsPDI, we do not modify G's
851    // call sites to point to F even when within the same translation unit.
852    if (!G->isInterposable() && !MergeFunctionsPDI) {
853      // Functions referred to by llvm.used/llvm.compiler.used are special:
854      // there are uses of the symbol name that are not visible to LLVM,
855      // usually from inline asm.
856      if (G->hasGlobalUnnamedAddr() && !Used.contains(G)) {
857        // G might have been a key in our GlobalNumberState, and it's illegal
858        // to replace a key in ValueMap<GlobalValue *> with a non-global.
859        GlobalNumbers.erase(G);
860        // If G's address is not significant, replace it entirely.
861        removeUsers(G);
862        G->replaceAllUsesWith(F);
863      } else {
864        // Redirect direct callers of G to F. (See note on MergeFunctionsPDI
865        // above).
866        replaceDirectCallers(G, F);
867      }
868    }
869
870    // If G was internal then we may have replaced all uses of G with F. If so,
871    // stop here and delete G. There's no need for a thunk. (See note on
872    // MergeFunctionsPDI above).
873    if (G->isDiscardableIfUnused() && G->use_empty() && !MergeFunctionsPDI) {
874      G->eraseFromParent();
875      ++NumFunctionsMerged;
876      return;
877    }
878
879    if (writeThunkOrAlias(F, G)) {
880      ++NumFunctionsMerged;
881    }
882  }
883}
884
885/// Replace function F by function G.
886void MergeFunctions::replaceFunctionInTree(const FunctionNode &FN,
887                                           Function *G) {
888  Function *F = FN.getFunc();
889  assert(FunctionComparator(F, G, &GlobalNumbers).compare() == 0 &&
890         "The two functions must be equal");
891
892  auto I = FNodesInTree.find(F);
893  assert(I != FNodesInTree.end() && "F should be in FNodesInTree");
894  assert(FNodesInTree.count(G) == 0 && "FNodesInTree should not contain G");
895
896  FnTreeType::iterator IterToFNInFnTree = I->second;
897  assert(&(*IterToFNInFnTree) == &FN && "F should map to FN in FNodesInTree.");
898  // Remove F -> FN and insert G -> FN
899  FNodesInTree.erase(I);
900  FNodesInTree.insert({G, IterToFNInFnTree});
901  // Replace F with G in FN, which is stored inside the FnTree.
902  FN.replaceBy(G);
903}
904
905// Ordering for functions that are equal under FunctionComparator
906static bool isFuncOrderCorrect(const Function *F, const Function *G) {
907  if (F->isInterposable() != G->isInterposable()) {
908    // Strong before weak, because the weak function may call the strong
909    // one, but not the other way around.
910    return !F->isInterposable();
911  }
912  if (F->hasLocalLinkage() != G->hasLocalLinkage()) {
913    // External before local, because we definitely have to keep the external
914    // function, but may be able to drop the local one.
915    return !F->hasLocalLinkage();
916  }
917  // Impose a total order (by name) on the replacement of functions. This is
918  // important when operating on more than one module independently to prevent
919  // cycles of thunks calling each other when the modules are linked together.
920  return F->getName() <= G->getName();
921}
922
923// Insert a ComparableFunction into the FnTree, or merge it away if equal to one
924// that was already inserted.
925bool MergeFunctions::insert(Function *NewFunction) {
926  std::pair<FnTreeType::iterator, bool> Result =
927      FnTree.insert(FunctionNode(NewFunction));
928
929  if (Result.second) {
930    assert(FNodesInTree.count(NewFunction) == 0);
931    FNodesInTree.insert({NewFunction, Result.first});
932    LLVM_DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName()
933                      << '\n');
934    return false;
935  }
936
937  const FunctionNode &OldF = *Result.first;
938
939  if (!isFuncOrderCorrect(OldF.getFunc(), NewFunction)) {
940    // Swap the two functions.
941    Function *F = OldF.getFunc();
942    replaceFunctionInTree(*Result.first, NewFunction);
943    NewFunction = F;
944    assert(OldF.getFunc() != F && "Must have swapped the functions.");
945  }
946
947  LLVM_DEBUG(dbgs() << "  " << OldF.getFunc()->getName()
948                    << " == " << NewFunction->getName() << '\n');
949
950  Function *DeleteF = NewFunction;
951  mergeTwoFunctions(OldF.getFunc(), DeleteF);
952  return true;
953}
954
955// Remove a function from FnTree. If it was already in FnTree, add
956// it to Deferred so that we'll look at it in the next round.
957void MergeFunctions::remove(Function *F) {
958  auto I = FNodesInTree.find(F);
959  if (I != FNodesInTree.end()) {
960    LLVM_DEBUG(dbgs() << "Deferred " << F->getName() << ".\n");
961    FnTree.erase(I->second);
962    // I->second has been invalidated, remove it from the FNodesInTree map to
963    // preserve the invariant.
964    FNodesInTree.erase(I);
965    Deferred.emplace_back(F);
966  }
967}
968
969// For each instruction used by the value, remove() the function that contains
970// the instruction. This should happen right before a call to RAUW.
971void MergeFunctions::removeUsers(Value *V) {
972  for (User *U : V->users())
973    if (auto *I = dyn_cast<Instruction>(U))
974      remove(I->getFunction());
975}
976