ArgumentPromotion.cpp revision 360784
1189747Ssam//===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===//
2189747Ssam//
3189747Ssam// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4189747Ssam// See https://llvm.org/LICENSE.txt for license information.
5189747Ssam// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6189747Ssam//
7189747Ssam//===----------------------------------------------------------------------===//
8189747Ssam//
9189747Ssam// This pass promotes "by reference" arguments to be "by value" arguments.  In
10189747Ssam// practice, this means looking for internal functions that have pointer
11189747Ssam// arguments.  If it can prove, through the use of alias analysis, that an
12189747Ssam// argument is *only* loaded, then it can pass the value into the function
13189747Ssam// instead of the address of the value.  This can cause recursive simplification
14189747Ssam// of code and lead to the elimination of allocas (especially in C++ template
15189747Ssam// code like the STL).
16189747Ssam//
17189747Ssam// This pass also handles aggregate arguments that are passed into a function,
18189747Ssam// scalarizing them if the elements of the aggregate are only loaded.  Note that
19189747Ssam// by default it refuses to scalarize aggregates which would require passing in
20189747Ssam// more than three operands to the function, because passing thousands of
21189747Ssam// operands for a large array or structure is unprofitable! This limit can be
22189747Ssam// configured or disabled, however.
23189747Ssam//
24189747Ssam// Note that this transformation could also be done for arguments that are only
25189747Ssam// stored to (returning the value instead), but does not currently.  This case
26189747Ssam// would be best handled when and if LLVM begins supporting multiple return
27217631Sadrian// values from functions.
28189747Ssam//
29189747Ssam//===----------------------------------------------------------------------===//
30189747Ssam
31217631Sadrian#include "llvm/Transforms/IPO/ArgumentPromotion.h"
32217631Sadrian#include "llvm/ADT/DepthFirstIterator.h"
33219393Sadrian#include "llvm/ADT/None.h"
34189747Ssam#include "llvm/ADT/Optional.h"
35189747Ssam#include "llvm/ADT/STLExtras.h"
36189747Ssam#include "llvm/ADT/SmallPtrSet.h"
37189747Ssam#include "llvm/ADT/SmallVector.h"
38189747Ssam#include "llvm/ADT/Statistic.h"
39189747Ssam#include "llvm/ADT/StringExtras.h"
40189747Ssam#include "llvm/ADT/Twine.h"
41189747Ssam#include "llvm/Analysis/AliasAnalysis.h"
42189747Ssam#include "llvm/Analysis/AssumptionCache.h"
43189747Ssam#include "llvm/Analysis/BasicAliasAnalysis.h"
44189747Ssam#include "llvm/Analysis/CGSCCPassManager.h"
45189747Ssam#include "llvm/Analysis/CallGraph.h"
46189747Ssam#include "llvm/Analysis/CallGraphSCCPass.h"
47189747Ssam#include "llvm/Analysis/LazyCallGraph.h"
48189747Ssam#include "llvm/Analysis/Loads.h"
49189747Ssam#include "llvm/Analysis/MemoryLocation.h"
50189747Ssam#include "llvm/Analysis/TargetLibraryInfo.h"
51189747Ssam#include "llvm/Analysis/TargetTransformInfo.h"
52189747Ssam#include "llvm/IR/Argument.h"
53189747Ssam#include "llvm/IR/Attributes.h"
54189747Ssam#include "llvm/IR/BasicBlock.h"
55189747Ssam#include "llvm/IR/CFG.h"
56189747Ssam#include "llvm/IR/CallSite.h"
57189747Ssam#include "llvm/IR/Constants.h"
58189747Ssam#include "llvm/IR/DataLayout.h"
59189747Ssam#include "llvm/IR/DerivedTypes.h"
60189747Ssam#include "llvm/IR/Function.h"
61189747Ssam#include "llvm/IR/IRBuilder.h"
62189747Ssam#include "llvm/IR/InstrTypes.h"
63189747Ssam#include "llvm/IR/Instruction.h"
64189747Ssam#include "llvm/IR/Instructions.h"
65189747Ssam#include "llvm/IR/Metadata.h"
66189747Ssam#include "llvm/IR/Module.h"
67189747Ssam#include "llvm/IR/NoFolder.h"
68189747Ssam#include "llvm/IR/PassManager.h"
69189747Ssam#include "llvm/IR/Type.h"
70189747Ssam#include "llvm/IR/Use.h"
71189747Ssam#include "llvm/IR/User.h"
72218764Sadrian#include "llvm/IR/Value.h"
73218764Sadrian#include "llvm/InitializePasses.h"
74218764Sadrian#include "llvm/Pass.h"
75218764Sadrian#include "llvm/Support/Casting.h"
76218764Sadrian#include "llvm/Support/Debug.h"
77218764Sadrian#include "llvm/Support/raw_ostream.h"
78218764Sadrian#include "llvm/Transforms/IPO.h"
79218764Sadrian#include <algorithm>
80218764Sadrian#include <cassert>
81218764Sadrian#include <cstdint>
82218764Sadrian#include <functional>
83218764Sadrian#include <iterator>
84218764Sadrian#include <map>
85218764Sadrian#include <set>
86218764Sadrian#include <string>
87218764Sadrian#include <utility>
88218764Sadrian#include <vector>
89218764Sadrian
90218764Sadrianusing namespace llvm;
91218764Sadrian
92218764Sadrian#define DEBUG_TYPE "argpromotion"
93218764Sadrian
94218764SadrianSTATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted");
95218764SadrianSTATISTIC(NumAggregatesPromoted, "Number of aggregate arguments promoted");
96218764SadrianSTATISTIC(NumByValArgsPromoted, "Number of byval arguments promoted");
97218764SadrianSTATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated");
98218764Sadrian
/// IndicesVector - Holds the constant indices of a single GEP instruction.
typedef std::vector<uint64_t> IndicesVector;
101189747Ssam
102221875Sadrian/// DoPromotion - This method actually performs the promotion of the specified
103221875Sadrian/// arguments, and returns the new function.  At this point, we know that it's
104221875Sadrian/// safe to do so.
105221875Sadrianstatic Function *
106189747SsamdoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote,
107189747Ssam            SmallPtrSetImpl<Argument *> &ByValArgsToTransform,
108189747Ssam            Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>>
109189747Ssam                ReplaceCallSite) {
110189747Ssam  // Start by computing a new prototype for the function, which is the same as
111217624Sadrian  // the old function, but has modified arguments.
112217624Sadrian  FunctionType *FTy = F->getFunctionType();
113189747Ssam  std::vector<Type *> Params;
114189747Ssam
115189747Ssam  using ScalarizeTable = std::set<std::pair<Type *, IndicesVector>>;
116189747Ssam
117189747Ssam  // ScalarizedElements - If we are promoting a pointer that has elements
118189747Ssam  // accessed out of it, keep track of which elements are accessed so that we
119189747Ssam  // can add one argument for each.
120219393Sadrian  //
121219441Sadrian  // Arguments that are directly loaded will have a zero element value here, to
122189747Ssam  // handle cases where there are both a direct load and GEP accesses.
123189747Ssam  std::map<Argument *, ScalarizeTable> ScalarizedElements;
124189747Ssam
125189747Ssam  // OriginalLoads - Keep track of a representative load instruction from the
126189747Ssam  // original function so that we can tell the alias analysis implementation
127189747Ssam  // what the new GEP/Load instructions we are inserting look like.
128189747Ssam  // We need to keep the original loads for each argument and the elements
129189747Ssam  // of the argument that are accessed.
130189747Ssam  std::map<std::pair<Argument *, IndicesVector>, LoadInst *> OriginalLoads;
131189747Ssam
132189747Ssam  // Attribute - Keep track of the parameter attributes for the arguments
133189747Ssam  // that we are *not* promoting. For the ones that we do promote, the parameter
134189747Ssam  // attributes are lost
135189747Ssam  SmallVector<AttributeSet, 8> ArgAttrVec;
136189747Ssam  AttributeList PAL = F->getAttributes();
137189747Ssam
138189747Ssam  // First, determine the new argument list
139189747Ssam  unsigned ArgNo = 0;
140189747Ssam  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
141189747Ssam       ++I, ++ArgNo) {
142189747Ssam    if (ByValArgsToTransform.count(&*I)) {
143189747Ssam      // Simple byval argument? Just add all the struct element types.
144189747Ssam      Type *AgTy = cast<PointerType>(I->getType())->getElementType();
145189747Ssam      StructType *STy = cast<StructType>(AgTy);
146189747Ssam      Params.insert(Params.end(), STy->element_begin(), STy->element_end());
147189747Ssam      ArgAttrVec.insert(ArgAttrVec.end(), STy->getNumElements(),
148189747Ssam                        AttributeSet());
149189747Ssam      ++NumByValArgsPromoted;
150189747Ssam    } else if (!ArgsToPromote.count(&*I)) {
151189747Ssam      // Unchanged argument
152219393Sadrian      Params.push_back(I->getType());
153219393Sadrian      ArgAttrVec.push_back(PAL.getParamAttributes(ArgNo));
154219393Sadrian    } else if (I->use_empty()) {
155219393Sadrian      // Dead argument (which are always marked as promotable)
156189747Ssam      ++NumArgumentsDead;
157189747Ssam
158189747Ssam      // There may be remaining metadata uses of the argument for things like
159189747Ssam      // llvm.dbg.value. Replace them with undef.
160189747Ssam      I->replaceAllUsesWith(UndefValue::get(I->getType()));
161189747Ssam    } else {
162189747Ssam      // Okay, this is being promoted. This means that the only uses are loads
163189747Ssam      // or GEPs which are only used by loads
164189747Ssam
165189747Ssam      // In this table, we will track which indices are loaded from the argument
166189747Ssam      // (where direct loads are tracked as no indices).
167189747Ssam      ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
168189747Ssam      for (User *U : I->users()) {
169189747Ssam        Instruction *UI = cast<Instruction>(U);
170189747Ssam        Type *SrcTy;
171189747Ssam        if (LoadInst *L = dyn_cast<LoadInst>(UI))
172189747Ssam          SrcTy = L->getType();
173189747Ssam        else
174189747Ssam          SrcTy = cast<GetElementPtrInst>(UI)->getSourceElementType();
175189747Ssam        IndicesVector Indices;
176189747Ssam        Indices.reserve(UI->getNumOperands() - 1);
177189747Ssam        // Since loads will only have a single operand, and GEPs only a single
178189747Ssam        // non-index operand, this will record direct loads without any indices,
179189747Ssam        // and gep+loads with the GEP indices.
180189747Ssam        for (User::op_iterator II = UI->op_begin() + 1, IE = UI->op_end();
181189747Ssam             II != IE; ++II)
182189747Ssam          Indices.push_back(cast<ConstantInt>(*II)->getSExtValue());
183189747Ssam        // GEPs with a single 0 index can be merged with direct loads
184189747Ssam        if (Indices.size() == 1 && Indices.front() == 0)
185189747Ssam          Indices.clear();
186203882Srpaulo        ArgIndices.insert(std::make_pair(SrcTy, Indices));
187189747Ssam        LoadInst *OrigLoad;
188189747Ssam        if (LoadInst *L = dyn_cast<LoadInst>(UI))
189189747Ssam          OrigLoad = L;
190189747Ssam        else
191189747Ssam          // Take any load, we will use it only to update Alias Analysis
192189747Ssam          OrigLoad = cast<LoadInst>(UI->user_back());
193189747Ssam        OriginalLoads[std::make_pair(&*I, Indices)] = OrigLoad;
194189747Ssam      }
195189747Ssam
196189747Ssam      // Add a parameter to the function for each element passed in.
197189747Ssam      for (const auto &ArgIndex : ArgIndices) {
198189747Ssam        // not allowed to dereference ->begin() if size() is 0
199189747Ssam        Params.push_back(GetElementPtrInst::getIndexedType(
200189747Ssam            cast<PointerType>(I->getType()->getScalarType())->getElementType(),
201203882Srpaulo            ArgIndex.second));
202189747Ssam        ArgAttrVec.push_back(AttributeSet());
203189747Ssam        assert(Params.back());
204189747Ssam      }
205189747Ssam
206189747Ssam      if (ArgIndices.size() == 1 && ArgIndices.begin()->second.empty())
207189747Ssam        ++NumArgumentsPromoted;
208189747Ssam      else
209189747Ssam        ++NumAggregatesPromoted;
210189747Ssam    }
211189747Ssam  }
212189747Ssam
213189747Ssam  Type *RetTy = FTy->getReturnType();
214189747Ssam
215189747Ssam  // Construct the new function type using the new arguments.
216189747Ssam  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());
217189747Ssam
218189747Ssam  // Create the new function body and insert it into the module.
219189747Ssam  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(),
220189747Ssam                                  F->getName());
221189747Ssam  NF->copyAttributesFrom(F);
222189747Ssam
223189747Ssam  // Patch the pointer to LLVM function in debug info descriptor.
224189747Ssam  NF->setSubprogram(F->getSubprogram());
225189747Ssam  F->setSubprogram(nullptr);
226189747Ssam
227189747Ssam  LLVM_DEBUG(dbgs() << "ARG PROMOTION:  Promoting to:" << *NF << "\n"
228189747Ssam                    << "From: " << *F);
229189747Ssam
230189747Ssam  // Recompute the parameter attributes list based on the new arguments for
231189747Ssam  // the function.
232189747Ssam  NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttributes(),
233189747Ssam                                       PAL.getRetAttributes(), ArgAttrVec));
234189747Ssam  ArgAttrVec.clear();
235189747Ssam
236189747Ssam  F->getParent()->getFunctionList().insert(F->getIterator(), NF);
237189747Ssam  NF->takeName(F);
238189747Ssam
239189747Ssam  // Loop over all of the callers of the function, transforming the call sites
240189747Ssam  // to pass in the loaded pointers.
241189747Ssam  //
242189747Ssam  SmallVector<Value *, 16> Args;
243189747Ssam  while (!F->use_empty()) {
244189747Ssam    CallSite CS(F->user_back());
245189747Ssam    assert(CS.getCalledFunction() == F);
246189747Ssam    Instruction *Call = CS.getInstruction();
247189747Ssam    const AttributeList &CallPAL = CS.getAttributes();
248189747Ssam    IRBuilder<NoFolder> IRB(Call);
249189747Ssam
250189747Ssam    // Loop over the operands, inserting GEP and loads in the caller as
251189747Ssam    // appropriate.
252189747Ssam    CallSite::arg_iterator AI = CS.arg_begin();
253189747Ssam    ArgNo = 0;
254219441Sadrian    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
255219441Sadrian         ++I, ++AI, ++ArgNo)
256219441Sadrian      if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) {
257219441Sadrian        Args.push_back(*AI); // Unmodified argument
258219441Sadrian        ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo));
259219441Sadrian      } else if (ByValArgsToTransform.count(&*I)) {
260219441Sadrian        // Emit a GEP and load for each element of the struct.
261219441Sadrian        Type *AgTy = cast<PointerType>(I->getType())->getElementType();
262219441Sadrian        StructType *STy = cast<StructType>(AgTy);
263219441Sadrian        Value *Idxs[2] = {
264219441Sadrian            ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr};
265219441Sadrian        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
266219441Sadrian          Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
267219441Sadrian          auto *Idx =
268219393Sadrian              IRB.CreateGEP(STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i));
269219393Sadrian          // TODO: Tell AA about the new values?
270219393Sadrian          Args.push_back(IRB.CreateLoad(STy->getElementType(i), Idx,
271219393Sadrian                                        Idx->getName() + ".val"));
272219393Sadrian          ArgAttrVec.push_back(AttributeSet());
273219445Sadrian        }
274219445Sadrian      } else if (!I->use_empty()) {
275219393Sadrian        // Non-dead argument: insert GEPs and loads as appropriate.
276219393Sadrian        ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
277221875Sadrian        // Store the Value* version of the indices in here, but declare it now
278221875Sadrian        // for reuse.
279189747Ssam        std::vector<Value *> Ops;
280189747Ssam        for (const auto &ArgIndex : ArgIndices) {
281189747Ssam          Value *V = *AI;
282189747Ssam          LoadInst *OrigLoad =
283189747Ssam              OriginalLoads[std::make_pair(&*I, ArgIndex.second)];
284189747Ssam          if (!ArgIndex.second.empty()) {
285189747Ssam            Ops.reserve(ArgIndex.second.size());
286189747Ssam            Type *ElTy = V->getType();
287189747Ssam            for (auto II : ArgIndex.second) {
288189747Ssam              // Use i32 to index structs, and i64 for others (pointers/arrays).
289189747Ssam              // This satisfies GEP constraints.
290189747Ssam              Type *IdxTy =
291189747Ssam                  (ElTy->isStructTy() ? Type::getInt32Ty(F->getContext())
292189747Ssam                                      : Type::getInt64Ty(F->getContext()));
293189747Ssam              Ops.push_back(ConstantInt::get(IdxTy, II));
294189747Ssam              // Keep track of the type we're currently indexing.
295189747Ssam              if (auto *ElPTy = dyn_cast<PointerType>(ElTy))
296189747Ssam                ElTy = ElPTy->getElementType();
297189747Ssam              else
298221875Sadrian                ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II);
299221875Sadrian            }
300221875Sadrian            // And create a GEP to extract those indices.
301189747Ssam            V = IRB.CreateGEP(ArgIndex.first, V, Ops, V->getName() + ".idx");
302189747Ssam            Ops.clear();
303189747Ssam          }
304189747Ssam          // Since we're replacing a load make sure we take the alignment
305189747Ssam          // of the previous load.
306189747Ssam          LoadInst *newLoad =
307189747Ssam              IRB.CreateLoad(OrigLoad->getType(), V, V->getName() + ".val");
308189747Ssam          newLoad->setAlignment(MaybeAlign(OrigLoad->getAlignment()));
309189747Ssam          // Transfer the AA info too.
310189747Ssam          AAMDNodes AAInfo;
311189747Ssam          OrigLoad->getAAMetadata(AAInfo);
312189747Ssam          newLoad->setAAMetadata(AAInfo);
313189747Ssam
314189747Ssam          Args.push_back(newLoad);
315189747Ssam          ArgAttrVec.push_back(AttributeSet());
316189747Ssam        }
317189747Ssam      }
318189747Ssam
319189747Ssam    // Push any varargs arguments on the list.
320189747Ssam    for (; AI != CS.arg_end(); ++AI, ++ArgNo) {
321189747Ssam      Args.push_back(*AI);
322189747Ssam      ArgAttrVec.push_back(CallPAL.getParamAttributes(ArgNo));
323189747Ssam    }
324189747Ssam
325189747Ssam    SmallVector<OperandBundleDef, 1> OpBundles;
326189747Ssam    CS.getOperandBundlesAsDefs(OpBundles);
327189747Ssam
328189747Ssam    CallSite NewCS;
329189747Ssam    if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
330189747Ssam      NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
331189747Ssam                                 Args, OpBundles, "", Call);
332189747Ssam    } else {
333189747Ssam      auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", Call);
334189747Ssam      NewCall->setTailCallKind(cast<CallInst>(Call)->getTailCallKind());
335221596Sadrian      NewCS = NewCall;
336221596Sadrian    }
337189747Ssam    NewCS.setCallingConv(CS.getCallingConv());
338189747Ssam    NewCS.setAttributes(
339189747Ssam        AttributeList::get(F->getContext(), CallPAL.getFnAttributes(),
340189747Ssam                           CallPAL.getRetAttributes(), ArgAttrVec));
341189747Ssam    NewCS->setDebugLoc(Call->getDebugLoc());
342189747Ssam    uint64_t W;
343189747Ssam    if (Call->extractProfTotalWeight(W))
344189747Ssam      NewCS->setProfWeight(W);
345219852Sadrian    Args.clear();
346189747Ssam    ArgAttrVec.clear();
347189747Ssam
348218068Sadrian    // Update the callgraph to know that the callsite has been transformed.
349218068Sadrian    if (ReplaceCallSite)
350218068Sadrian      (*ReplaceCallSite)(CS, NewCS);
351218068Sadrian
352218068Sadrian    if (!Call->use_empty()) {
353218068Sadrian      Call->replaceAllUsesWith(NewCS.getInstruction());
354218068Sadrian      NewCS->takeName(Call);
355218068Sadrian    }
356218068Sadrian
357203882Srpaulo    // Finally, remove the old call from the program, reducing the use-count of
358189747Ssam    // F.
359189747Ssam    Call->eraseFromParent();
360189747Ssam  }
361189747Ssam
362189747Ssam  const DataLayout &DL = F->getParent()->getDataLayout();
363189747Ssam
364189747Ssam  // Since we have now created the new function, splice the body of the old
365189747Ssam  // function right into the new function, leaving the old rotting hulk of the
366189747Ssam  // function empty.
367189747Ssam  NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());
368189747Ssam
369189747Ssam  // Loop over the argument list, transferring uses of the old arguments over to
370189747Ssam  // the new arguments, also transferring over the names as well.
371189747Ssam  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(),
372189747Ssam                              I2 = NF->arg_begin();
373189747Ssam       I != E; ++I) {
374189747Ssam    if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) {
375189747Ssam      // If this is an unmodified argument, move the name and users over to the
376189747Ssam      // new version.
377203882Srpaulo      I->replaceAllUsesWith(&*I2);
378189747Ssam      I2->takeName(&*I);
379189747Ssam      ++I2;
380189747Ssam      continue;
381189747Ssam    }
382189747Ssam
383189747Ssam    if (ByValArgsToTransform.count(&*I)) {
384189747Ssam      // In the callee, we create an alloca, and store each of the new incoming
385189747Ssam      // arguments into the alloca.
386219441Sadrian      Instruction *InsertPt = &NF->begin()->front();
387219441Sadrian
388189747Ssam      // Just add all the struct element types.
389189747Ssam      Type *AgTy = cast<PointerType>(I->getType())->getElementType();
390189747Ssam      Value *TheAlloca =
391189747Ssam          new AllocaInst(AgTy, DL.getAllocaAddrSpace(), nullptr,
392189747Ssam                         MaybeAlign(I->getParamAlignment()), "", InsertPt);
393189747Ssam      StructType *STy = cast<StructType>(AgTy);
394189747Ssam      Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(F->getContext()), 0),
395189747Ssam                        nullptr};
396189747Ssam
397189747Ssam      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
398189747Ssam        Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
399189747Ssam        Value *Idx = GetElementPtrInst::Create(
400189747Ssam            AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(i),
401189747Ssam            InsertPt);
402189747Ssam        I2->setName(I->getName() + "." + Twine(i));
403189747Ssam        new StoreInst(&*I2++, Idx, InsertPt);
404189747Ssam      }
405189747Ssam
406189747Ssam      // Anything that used the arg should now use the alloca.
407189747Ssam      I->replaceAllUsesWith(TheAlloca);
408189747Ssam      TheAlloca->takeName(&*I);
409189747Ssam
410189747Ssam      // If the alloca is used in a call, we must clear the tail flag since
411189747Ssam      // the callee now uses an alloca from the caller.
412219441Sadrian      for (User *U : TheAlloca->users()) {
413219441Sadrian        CallInst *Call = dyn_cast<CallInst>(U);
414219441Sadrian        if (!Call)
415219441Sadrian          continue;
416219441Sadrian        Call->setTailCall(false);
417219441Sadrian      }
418219441Sadrian      continue;
419189747Ssam    }
420189747Ssam
421219441Sadrian    if (I->use_empty())
422219441Sadrian      continue;
423219441Sadrian
424219441Sadrian    // Otherwise, if we promoted this argument, then all users are load
425219441Sadrian    // instructions (or GEPs with only load users), and all loads should be
426219441Sadrian    // using the new argument that we added.
427219441Sadrian    ScalarizeTable &ArgIndices = ScalarizedElements[&*I];
428219441Sadrian
429219441Sadrian    while (!I->use_empty()) {
430219441Sadrian      if (LoadInst *LI = dyn_cast<LoadInst>(I->user_back())) {
431219441Sadrian        assert(ArgIndices.begin()->second.empty() &&
432219441Sadrian               "Load element should sort to front!");
433219441Sadrian        I2->setName(I->getName() + ".val");
434219441Sadrian        LI->replaceAllUsesWith(&*I2);
435219441Sadrian        LI->eraseFromParent();
436219441Sadrian        LLVM_DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName()
437219441Sadrian                          << "' in function '" << F->getName() << "'\n");
438219441Sadrian      } else {
439189747Ssam        GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->user_back());
440189747Ssam        IndicesVector Operands;
441189747Ssam        Operands.reserve(GEP->getNumIndices());
442189747Ssam        for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end();
443189747Ssam             II != IE; ++II)
444189747Ssam          Operands.push_back(cast<ConstantInt>(*II)->getSExtValue());
445189747Ssam
446189747Ssam        // GEPs with a single 0 index can be merged with direct loads
447189747Ssam        if (Operands.size() == 1 && Operands.front() == 0)
448189747Ssam          Operands.clear();
449189747Ssam
450189747Ssam        Function::arg_iterator TheArg = I2;
451189747Ssam        for (ScalarizeTable::iterator It = ArgIndices.begin();
452189747Ssam             It->second != Operands; ++It, ++TheArg) {
453189747Ssam          assert(It != ArgIndices.end() && "GEP not handled??");
454189747Ssam        }
455189747Ssam
456189747Ssam        std::string NewName = I->getName();
457189747Ssam        for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
458189747Ssam          NewName += "." + utostr(Operands[i]);
459189747Ssam        }
460189747Ssam        NewName += ".val";
461203930Srpaulo        TheArg->setName(NewName);
462189747Ssam
463189747Ssam        LLVM_DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName()
464189747Ssam                          << "' of function '" << NF->getName() << "'\n");
465189747Ssam
466189747Ssam        // All of the uses must be load instructions.  Replace them all with
467189747Ssam        // the argument specified by ArgNo.
468189747Ssam        while (!GEP->use_empty()) {
469189747Ssam          LoadInst *L = cast<LoadInst>(GEP->user_back());
470189747Ssam          L->replaceAllUsesWith(&*TheArg);
471189747Ssam          L->eraseFromParent();
472189747Ssam        }
473189747Ssam        GEP->eraseFromParent();
474189747Ssam      }
475189747Ssam    }
476189747Ssam
477189747Ssam    // Increment I2 past all of the arguments added for this promoted pointer.
478189747Ssam    std::advance(I2, ArgIndices.size());
479189747Ssam  }
480189747Ssam
481189747Ssam  return NF;
482189747Ssam}
483189747Ssam
484189747Ssam/// Return true if we can prove that all callees pass in a valid pointer for the
485189747Ssam/// specified function argument.
486189747Ssamstatic bool allCallersPassValidPointerForArgument(Argument *Arg, Type *Ty) {
487189747Ssam  Function *Callee = Arg->getParent();
488189747Ssam  const DataLayout &DL = Callee->getParent()->getDataLayout();
489189747Ssam
490189747Ssam  unsigned ArgNo = Arg->getArgNo();
491189747Ssam
492189747Ssam  // Look at all call sites of the function.  At this point we know we only have
493189747Ssam  // direct callees.
494189747Ssam  for (User *U : Callee->users()) {
495189747Ssam    CallSite CS(U);
496189747Ssam    assert(CS && "Should only have direct calls!");
497189747Ssam
498189747Ssam    if (!isDereferenceablePointer(CS.getArgument(ArgNo), Ty, DL))
499189747Ssam      return false;
500189747Ssam  }
501189747Ssam  return true;
502189747Ssam}
503189747Ssam
504189747Ssam/// Returns true if Prefix is a prefix of longer. That means, Longer has a size
505189747Ssam/// that is greater than or equal to the size of prefix, and each of the
506189747Ssam/// elements in Prefix is the same as the corresponding elements in Longer.
507189747Ssam///
508189747Ssam/// This means it also returns true when Prefix and Longer are equal!
509189747Ssamstatic bool isPrefix(const IndicesVector &Prefix, const IndicesVector &Longer) {
510189747Ssam  if (Prefix.size() > Longer.size())
511189747Ssam    return false;
512189747Ssam  return std::equal(Prefix.begin(), Prefix.end(), Longer.begin());
513189747Ssam}
514189747Ssam
515189747Ssam/// Checks if Indices, or a prefix of Indices, is in Set.
516189747Ssamstatic bool prefixIn(const IndicesVector &Indices,
517189747Ssam                     std::set<IndicesVector> &Set) {
518189747Ssam  std::set<IndicesVector>::iterator Low;
519189747Ssam  Low = Set.upper_bound(Indices);
520189747Ssam  if (Low != Set.begin())
521189747Ssam    Low--;
522189747Ssam  // Low is now the last element smaller than or equal to Indices. This means
523189747Ssam  // it points to a prefix of Indices (possibly Indices itself), if such
524189747Ssam  // prefix exists.
525189747Ssam  //
526189747Ssam  // This load is safe if any prefix of its operands is safe to load.
527189747Ssam  return Low != Set.end() && isPrefix(*Low, Indices);
528189747Ssam}
529189747Ssam
530189747Ssam/// Mark the given indices (ToMark) as safe in the given set of indices
531189747Ssam/// (Safe). Marking safe usually means adding ToMark to Safe. However, if there
532189747Ssam/// is already a prefix of Indices in Safe, Indices are implicitely marked safe
533189747Ssam/// already. Furthermore, any indices that Indices is itself a prefix of, are
534189747Ssam/// removed from Safe (since they are implicitely safe because of Indices now).
535189747Ssamstatic void markIndicesSafe(const IndicesVector &ToMark,
536189747Ssam                            std::set<IndicesVector> &Safe) {
537189747Ssam  std::set<IndicesVector>::iterator Low;
538189747Ssam  Low = Safe.upper_bound(ToMark);
539189747Ssam  // Guard against the case where Safe is empty
540189747Ssam  if (Low != Safe.begin())
541189747Ssam    Low--;
542189747Ssam  // Low is now the last element smaller than or equal to Indices. This
543189747Ssam  // means it points to a prefix of Indices (possibly Indices itself), if
544189747Ssam  // such prefix exists.
545189747Ssam  if (Low != Safe.end()) {
546189747Ssam    if (isPrefix(*Low, ToMark))
547189747Ssam      // If there is already a prefix of these indices (or exactly these
548189747Ssam      // indices) marked a safe, don't bother adding these indices
549189747Ssam      return;
550189747Ssam
551189747Ssam    // Increment Low, so we can use it as a "insert before" hint
552189747Ssam    ++Low;
553189747Ssam  }
554189747Ssam  // Insert
555189747Ssam  Low = Safe.insert(Low, ToMark);
556189747Ssam  ++Low;
557189747Ssam  // If there we're a prefix of longer index list(s), remove those
558189747Ssam  std::set<IndicesVector>::iterator End = Safe.end();
559189747Ssam  while (Low != End && isPrefix(ToMark, *Low)) {
560189747Ssam    std::set<IndicesVector>::iterator Remove = Low;
561189747Ssam    ++Low;
562189747Ssam    Safe.erase(Remove);
563189747Ssam  }
564189747Ssam}
565189747Ssam
566189747Ssam/// isSafeToPromoteArgument - As you might guess from the name of this method,
567189747Ssam/// it checks to see if it is both safe and useful to promote the argument.
568189747Ssam/// This method limits promotion of aggregates to only promote up to three
569189747Ssam/// elements of the aggregate in order to avoid exploding the number of
570189747Ssam/// arguments passed in.
571189747Ssamstatic bool isSafeToPromoteArgument(Argument *Arg, Type *ByValTy, AAResults &AAR,
572189747Ssam                                    unsigned MaxElements) {
573189747Ssam  using GEPIndicesSet = std::set<IndicesVector>;
574189747Ssam
575189747Ssam  // Quick exit for unused arguments
576189747Ssam  if (Arg->use_empty())
577189747Ssam    return true;
578189747Ssam
579189747Ssam  // We can only promote this argument if all of the uses are loads, or are GEP
580189747Ssam  // instructions (with constant indices) that are subsequently loaded.
581189747Ssam  //
582189747Ssam  // Promoting the argument causes it to be loaded in the caller
583189747Ssam  // unconditionally. This is only safe if we can prove that either the load
584189747Ssam  // would have happened in the callee anyway (ie, there is a load in the entry
585189747Ssam  // block) or the pointer passed in at every call site is guaranteed to be
586189747Ssam  // valid.
587189747Ssam  // In the former case, invalid loads can happen, but would have happened
588189747Ssam  // anyway, in the latter case, invalid loads won't happen. This prevents us
589189747Ssam  // from introducing an invalid load that wouldn't have happened in the
590189747Ssam  // original code.
591189747Ssam  //
592189747Ssam  // This set will contain all sets of indices that are loaded in the entry
593189747Ssam  // block, and thus are safe to unconditionally load in the caller.
594189747Ssam  GEPIndicesSet SafeToUnconditionallyLoad;
595189747Ssam
596189747Ssam  // This set contains all the sets of indices that we are planning to promote.
597189747Ssam  // This makes it possible to limit the number of arguments added.
598189747Ssam  GEPIndicesSet ToPromote;
599189747Ssam
600189747Ssam  // If the pointer is always valid, any load with first index 0 is valid.
601189747Ssam
602189747Ssam  if (ByValTy)
603189747Ssam    SafeToUnconditionallyLoad.insert(IndicesVector(1, 0));
604189747Ssam
605189747Ssam  // Whenever a new underlying type for the operand is found, make sure it's
606189747Ssam  // consistent with the GEPs and loads we've already seen and, if necessary,
607189747Ssam  // use it to see if all incoming pointers are valid (which implies the 0-index
608189747Ssam  // is safe).
609189747Ssam  Type *BaseTy = ByValTy;
610189747Ssam  auto UpdateBaseTy = [&](Type *NewBaseTy) {
611189747Ssam    if (BaseTy)
612189747Ssam      return BaseTy == NewBaseTy;
613189747Ssam
614189747Ssam    BaseTy = NewBaseTy;
615189747Ssam    if (allCallersPassValidPointerForArgument(Arg, BaseTy)) {
616189747Ssam      assert(SafeToUnconditionallyLoad.empty());
617189747Ssam      SafeToUnconditionallyLoad.insert(IndicesVector(1, 0));
618189747Ssam    }
619189747Ssam
620189747Ssam    return true;
621189747Ssam  };
622189747Ssam
623189747Ssam  // First, iterate the entry block and mark loads of (geps of) arguments as
624189747Ssam  // safe.
625189747Ssam  BasicBlock &EntryBlock = Arg->getParent()->front();
626189747Ssam  // Declare this here so we can reuse it
627189747Ssam  IndicesVector Indices;
628189747Ssam  for (Instruction &I : EntryBlock)
629189747Ssam    if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
630189747Ssam      Value *V = LI->getPointerOperand();
631189747Ssam      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
632189747Ssam        V = GEP->getPointerOperand();
633189747Ssam        if (V == Arg) {
634189747Ssam          // This load actually loads (part of) Arg? Check the indices then.
635189747Ssam          Indices.reserve(GEP->getNumIndices());
636189747Ssam          for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end();
637189747Ssam               II != IE; ++II)
638189747Ssam            if (ConstantInt *CI = dyn_cast<ConstantInt>(*II))
639189747Ssam              Indices.push_back(CI->getSExtValue());
640189747Ssam            else
641189747Ssam              // We found a non-constant GEP index for this argument? Bail out
642189747Ssam              // right away, can't promote this argument at all.
643189747Ssam              return false;
644189747Ssam
645189747Ssam          if (!UpdateBaseTy(GEP->getSourceElementType()))
646189747Ssam            return false;
647189747Ssam
648189747Ssam          // Indices checked out, mark them as safe
649189747Ssam          markIndicesSafe(Indices, SafeToUnconditionallyLoad);
650189747Ssam          Indices.clear();
651189747Ssam        }
652189747Ssam      } else if (V == Arg) {
653189747Ssam        // Direct loads are equivalent to a GEP with a single 0 index.
654189747Ssam        markIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad);
655189747Ssam
656189747Ssam        if (BaseTy && LI->getType() != BaseTy)
657189747Ssam          return false;
658189747Ssam
659189747Ssam        BaseTy = LI->getType();
660189747Ssam      }
661189747Ssam    }
662189747Ssam
663189747Ssam  // Now, iterate all uses of the argument to see if there are any uses that are
664189747Ssam  // not (GEP+)loads, or any (GEP+)loads that are not safe to promote.
665189747Ssam  SmallVector<LoadInst *, 16> Loads;
666189747Ssam  IndicesVector Operands;
667189747Ssam  for (Use &U : Arg->uses()) {
668189747Ssam    User *UR = U.getUser();
669189747Ssam    Operands.clear();
670189747Ssam    if (LoadInst *LI = dyn_cast<LoadInst>(UR)) {
671189747Ssam      // Don't hack volatile/atomic loads
672189747Ssam      if (!LI->isSimple())
673189747Ssam        return false;
674189747Ssam      Loads.push_back(LI);
675189747Ssam      // Direct loads are equivalent to a GEP with a zero index and then a load.
676189747Ssam      Operands.push_back(0);
677189747Ssam
678189747Ssam      if (!UpdateBaseTy(LI->getType()))
679189747Ssam        return false;
680189747Ssam    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(UR)) {
681189747Ssam      if (GEP->use_empty()) {
682189747Ssam        // Dead GEP's cause trouble later.  Just remove them if we run into
683189747Ssam        // them.
684189747Ssam        GEP->eraseFromParent();
685189747Ssam        // TODO: This runs the above loop over and over again for dead GEPs
686189747Ssam        // Couldn't we just do increment the UI iterator earlier and erase the
687189747Ssam        // use?
688189747Ssam        return isSafeToPromoteArgument(Arg, ByValTy, AAR, MaxElements);
689189747Ssam      }
690189747Ssam
691189747Ssam      if (!UpdateBaseTy(GEP->getSourceElementType()))
692189747Ssam        return false;
693189747Ssam
694189747Ssam      // Ensure that all of the indices are constants.
695189747Ssam      for (User::op_iterator i = GEP->idx_begin(), e = GEP->idx_end(); i != e;
696189747Ssam           ++i)
697189747Ssam        if (ConstantInt *C = dyn_cast<ConstantInt>(*i))
698189747Ssam          Operands.push_back(C->getSExtValue());
699189747Ssam        else
700189747Ssam          return false; // Not a constant operand GEP!
701189747Ssam
702189747Ssam      // Ensure that the only users of the GEP are load instructions.
703189747Ssam      for (User *GEPU : GEP->users())
704189747Ssam        if (LoadInst *LI = dyn_cast<LoadInst>(GEPU)) {
705189747Ssam          // Don't hack volatile/atomic loads
706189747Ssam          if (!LI->isSimple())
707189747Ssam            return false;
708189747Ssam          Loads.push_back(LI);
709189747Ssam        } else {
710189747Ssam          // Other uses than load?
711189747Ssam          return false;
712189747Ssam        }
713189747Ssam    } else {
714189747Ssam      return false; // Not a load or a GEP.
715189747Ssam    }
716189747Ssam
717189747Ssam    // Now, see if it is safe to promote this load / loads of this GEP. Loading
718189747Ssam    // is safe if Operands, or a prefix of Operands, is marked as safe.
719189747Ssam    if (!prefixIn(Operands, SafeToUnconditionallyLoad))
720189747Ssam      return false;
721189747Ssam
722189747Ssam    // See if we are already promoting a load with these indices. If not, check
723189747Ssam    // to make sure that we aren't promoting too many elements.  If so, nothing
724189747Ssam    // to do.
725189747Ssam    if (ToPromote.find(Operands) == ToPromote.end()) {
726189747Ssam      if (MaxElements > 0 && ToPromote.size() == MaxElements) {
727189747Ssam        LLVM_DEBUG(dbgs() << "argpromotion not promoting argument '"
728189747Ssam                          << Arg->getName()
729189747Ssam                          << "' because it would require adding more "
730189747Ssam                          << "than " << MaxElements
731189747Ssam                          << " arguments to the function.\n");
732189747Ssam        // We limit aggregate promotion to only promoting up to a fixed number
733189747Ssam        // of elements of the aggregate.
734189747Ssam        return false;
735189747Ssam      }
736189747Ssam      ToPromote.insert(std::move(Operands));
737189747Ssam    }
738189747Ssam  }
739189747Ssam
740189747Ssam  if (Loads.empty())
741189747Ssam    return true; // No users, this is a dead argument.
742189747Ssam
743189747Ssam  // Okay, now we know that the argument is only used by load instructions and
744189747Ssam  // it is safe to unconditionally perform all of them. Use alias analysis to
745189747Ssam  // check to see if the pointer is guaranteed to not be modified from entry of
746189747Ssam  // the function to each of the load instructions.
747189747Ssam
748189747Ssam  // Because there could be several/many load instructions, remember which
749189747Ssam  // blocks we know to be transparent to the load.
750189747Ssam  df_iterator_default_set<BasicBlock *, 16> TranspBlocks;
751189747Ssam
752189747Ssam  for (LoadInst *Load : Loads) {
753189747Ssam    // Check to see if the load is invalidated from the start of the block to
754189747Ssam    // the load itself.
755189747Ssam    BasicBlock *BB = Load->getParent();
756189747Ssam
757189747Ssam    MemoryLocation Loc = MemoryLocation::get(Load);
758189747Ssam    if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, ModRefInfo::Mod))
759189747Ssam      return false; // Pointer is invalidated!
760189747Ssam
761189747Ssam    // Now check every path from the entry block to the load for transparency.
762189747Ssam    // To do this, we perform a depth first search on the inverse CFG from the
763189747Ssam    // loading block.
764189747Ssam    for (BasicBlock *P : predecessors(BB)) {
765189747Ssam      for (BasicBlock *TranspBB : inverse_depth_first_ext(P, TranspBlocks))
766189747Ssam        if (AAR.canBasicBlockModify(*TranspBB, Loc))
767189747Ssam          return false;
768189747Ssam    }
769189747Ssam  }
770189747Ssam
771203882Srpaulo  // If the path from the entry of the function to each load is free of
772189747Ssam  // instructions that potentially invalidate the load, we can make the
773189747Ssam  // transformation!
774189747Ssam  return true;
775189747Ssam}
776189747Ssam
777218150Sadrian/// Checks if a type could have padding bytes.
778218150Sadrianstatic bool isDenselyPacked(Type *type, const DataLayout &DL) {
779218150Sadrian  // There is no size information, so be conservative.
780218150Sadrian  if (!type->isSized())
781189747Ssam    return false;
782189747Ssam
783189747Ssam  // If the alloc size is not equal to the storage size, then there are padding
784189747Ssam  // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128.
785189747Ssam  if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type))
786189747Ssam    return false;
787189747Ssam
788189747Ssam  if (!isa<CompositeType>(type))
789189747Ssam    return true;
790189747Ssam
791189747Ssam  // For homogenous sequential types, check for padding within members.
792220325Sadrian  if (SequentialType *seqTy = dyn_cast<SequentialType>(type))
793220325Sadrian    return isDenselyPacked(seqTy->getElementType(), DL);
794221603Sadrian
795221603Sadrian  // Check for padding within and between elements of a struct.
796221603Sadrian  StructType *StructTy = cast<StructType>(type);
797221667Sadrian  const StructLayout *Layout = DL.getStructLayout(StructTy);
798221603Sadrian  uint64_t StartPos = 0;
799221667Sadrian  for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) {
800221667Sadrian    Type *ElTy = StructTy->getElementType(i);
801221667Sadrian    if (!isDenselyPacked(ElTy, DL))
802221667Sadrian      return false;
803221667Sadrian    if (StartPos != Layout->getElementOffsetInBits(i))
804221667Sadrian      return false;
805189747Ssam    StartPos += DL.getTypeAllocSizeInBits(ElTy);
806189747Ssam  }
807189747Ssam
808189747Ssam  return true;
809189747Ssam}
810189747Ssam
811218708Sadrian/// Checks if the padding bytes of an argument could be accessed.
812218708Sadrianstatic bool canPaddingBeAccessed(Argument *arg) {
813218708Sadrian  assert(arg->hasByValAttr());
814218708Sadrian
815218708Sadrian  // Track all the pointers to the argument to make sure they are not captured.
816218708Sadrian  SmallPtrSet<Value *, 16> PtrValues;
817189747Ssam  PtrValues.insert(arg);
818189747Ssam
819189747Ssam  // Track all of the stores.
820189747Ssam  SmallVector<StoreInst *, 16> Stores;
821189747Ssam
822218708Sadrian  // Scan through the uses recursively to make sure the pointer is always used
823189747Ssam  // sanely.
824189747Ssam  SmallVector<Value *, 16> WorkList;
825189747Ssam  WorkList.insert(WorkList.end(), arg->user_begin(), arg->user_end());
826189747Ssam  while (!WorkList.empty()) {
827189747Ssam    Value *V = WorkList.back();
828189747Ssam    WorkList.pop_back();
829189747Ssam    if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) {
830189747Ssam      if (PtrValues.insert(V).second)
831189747Ssam        WorkList.insert(WorkList.end(), V->user_begin(), V->user_end());
832189747Ssam    } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) {
833189747Ssam      Stores.push_back(Store);
834189747Ssam    } else if (!isa<LoadInst>(V)) {
835189747Ssam      return true;
836189747Ssam    }
837189747Ssam  }
838189747Ssam
839189747Ssam  // Check to make sure the pointers aren't captured
840189747Ssam  for (StoreInst *Store : Stores)
841217641Sadrian    if (PtrValues.count(Store->getValueOperand()))
842217641Sadrian      return true;
843189747Ssam
844189747Ssam  return false;
845217684Sadrian}
846217684Sadrian
847217684Sadrianstatic bool areFunctionArgsABICompatible(
848217684Sadrian    const Function &F, const TargetTransformInfo &TTI,
849218708Sadrian    SmallPtrSetImpl<Argument *> &ArgsToPromote,
850189747Ssam    SmallPtrSetImpl<Argument *> &ByValArgsToTransform) {
851189747Ssam  for (const Use &U : F.uses()) {
852189747Ssam    CallSite CS(U.getUser());
853189747Ssam    const Function *Caller = CS.getCaller();
854189747Ssam    const Function *Callee = CS.getCalledFunction();
855189747Ssam    if (!TTI.areFunctionArgsABICompatible(Caller, Callee, ArgsToPromote) ||
856189747Ssam        !TTI.areFunctionArgsABICompatible(Caller, Callee, ByValArgsToTransform))
857189747Ssam      return false;
858203882Srpaulo  }
859203882Srpaulo  return true;
860203882Srpaulo}
861189747Ssam
862189747Ssam/// PromoteArguments - This method checks the specified function to see if there
863189747Ssam/// are any promotable arguments and if it is safe to promote the function (for
864/// example, all callers are direct).  If safe to promote some arguments, it
865/// calls the DoPromotion method.
866static Function *
867promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter,
868                 unsigned MaxElements,
869                 Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>>
870                     ReplaceCallSite,
871                 const TargetTransformInfo &TTI) {
872  // Don't perform argument promotion for naked functions; otherwise we can end
873  // up removing parameters that are seemingly 'not used' as they are referred
874  // to in the assembly.
875  if(F->hasFnAttribute(Attribute::Naked))
876    return nullptr;
877
878  // Make sure that it is local to this module.
879  if (!F->hasLocalLinkage())
880    return nullptr;
881
882  // Don't promote arguments for variadic functions. Adding, removing, or
883  // changing non-pack parameters can change the classification of pack
884  // parameters. Frontends encode that classification at the call site in the
885  // IR, while in the callee the classification is determined dynamically based
886  // on the number of registers consumed so far.
887  if (F->isVarArg())
888    return nullptr;
889
890  // Don't transform functions that receive inallocas, as the transformation may
891  // not be safe depending on calling convention.
892  if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca))
893    return nullptr;
894
895  // First check: see if there are any pointer arguments!  If not, quick exit.
896  SmallVector<Argument *, 16> PointerArgs;
897  for (Argument &I : F->args())
898    if (I.getType()->isPointerTy())
899      PointerArgs.push_back(&I);
900  if (PointerArgs.empty())
901    return nullptr;
902
903  // Second check: make sure that all callers are direct callers.  We can't
904  // transform functions that have indirect callers.  Also see if the function
905  // is self-recursive and check that target features are compatible.
906  bool isSelfRecursive = false;
907  for (Use &U : F->uses()) {
908    CallSite CS(U.getUser());
909    // Must be a direct call.
910    if (CS.getInstruction() == nullptr || !CS.isCallee(&U))
911      return nullptr;
912
913    // Can't change signature of musttail callee
914    if (CS.isMustTailCall())
915      return nullptr;
916
917    if (CS.getInstruction()->getParent()->getParent() == F)
918      isSelfRecursive = true;
919  }
920
921  // Can't change signature of musttail caller
922  // FIXME: Support promoting whole chain of musttail functions
923  for (BasicBlock &BB : *F)
924    if (BB.getTerminatingMustTailCall())
925      return nullptr;
926
927  const DataLayout &DL = F->getParent()->getDataLayout();
928
929  AAResults &AAR = AARGetter(*F);
930
931  // Check to see which arguments are promotable.  If an argument is promotable,
932  // add it to ArgsToPromote.
933  SmallPtrSet<Argument *, 8> ArgsToPromote;
934  SmallPtrSet<Argument *, 8> ByValArgsToTransform;
935  for (Argument *PtrArg : PointerArgs) {
936    Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType();
937
938    // Replace sret attribute with noalias. This reduces register pressure by
939    // avoiding a register copy.
940    if (PtrArg->hasStructRetAttr()) {
941      unsigned ArgNo = PtrArg->getArgNo();
942      F->removeParamAttr(ArgNo, Attribute::StructRet);
943      F->addParamAttr(ArgNo, Attribute::NoAlias);
944      for (Use &U : F->uses()) {
945        CallSite CS(U.getUser());
946        CS.removeParamAttr(ArgNo, Attribute::StructRet);
947        CS.addParamAttr(ArgNo, Attribute::NoAlias);
948      }
949    }
950
951    // If this is a byval argument, and if the aggregate type is small, just
952    // pass the elements, which is always safe, if the passed value is densely
953    // packed or if we can prove the padding bytes are never accessed.
954    bool isSafeToPromote =
955        PtrArg->hasByValAttr() &&
956        (isDenselyPacked(AgTy, DL) || !canPaddingBeAccessed(PtrArg));
957    if (isSafeToPromote) {
958      if (StructType *STy = dyn_cast<StructType>(AgTy)) {
959        if (MaxElements > 0 && STy->getNumElements() > MaxElements) {
960          LLVM_DEBUG(dbgs() << "argpromotion disable promoting argument '"
961                            << PtrArg->getName()
962                            << "' because it would require adding more"
963                            << " than " << MaxElements
964                            << " arguments to the function.\n");
965          continue;
966        }
967
968        // If all the elements are single-value types, we can promote it.
969        bool AllSimple = true;
970        for (const auto *EltTy : STy->elements()) {
971          if (!EltTy->isSingleValueType()) {
972            AllSimple = false;
973            break;
974          }
975        }
976
977        // Safe to transform, don't even bother trying to "promote" it.
978        // Passing the elements as a scalar will allow sroa to hack on
979        // the new alloca we introduce.
980        if (AllSimple) {
981          ByValArgsToTransform.insert(PtrArg);
982          continue;
983        }
984      }
985    }
986
987    // If the argument is a recursive type and we're in a recursive
988    // function, we could end up infinitely peeling the function argument.
989    if (isSelfRecursive) {
990      if (StructType *STy = dyn_cast<StructType>(AgTy)) {
991        bool RecursiveType = false;
992        for (const auto *EltTy : STy->elements()) {
993          if (EltTy == PtrArg->getType()) {
994            RecursiveType = true;
995            break;
996          }
997        }
998        if (RecursiveType)
999          continue;
1000      }
1001    }
1002
1003    // Otherwise, see if we can promote the pointer to its value.
1004    Type *ByValTy =
1005        PtrArg->hasByValAttr() ? PtrArg->getParamByValType() : nullptr;
1006    if (isSafeToPromoteArgument(PtrArg, ByValTy, AAR, MaxElements))
1007      ArgsToPromote.insert(PtrArg);
1008  }
1009
1010  // No promotable pointer arguments.
1011  if (ArgsToPromote.empty() && ByValArgsToTransform.empty())
1012    return nullptr;
1013
1014  if (!areFunctionArgsABICompatible(*F, TTI, ArgsToPromote,
1015                                    ByValArgsToTransform))
1016    return nullptr;
1017
1018  return doPromotion(F, ArgsToPromote, ByValArgsToTransform, ReplaceCallSite);
1019}
1020
1021PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C,
1022                                             CGSCCAnalysisManager &AM,
1023                                             LazyCallGraph &CG,
1024                                             CGSCCUpdateResult &UR) {
1025  bool Changed = false, LocalChange;
1026
1027  // Iterate until we stop promoting from this SCC.
1028  do {
1029    LocalChange = false;
1030
1031    for (LazyCallGraph::Node &N : C) {
1032      Function &OldF = N.getFunction();
1033
1034      FunctionAnalysisManager &FAM =
1035          AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
1036      // FIXME: This lambda must only be used with this function. We should
1037      // skip the lambda and just get the AA results directly.
1038      auto AARGetter = [&](Function &F) -> AAResults & {
1039        assert(&F == &OldF && "Called with an unexpected function!");
1040        return FAM.getResult<AAManager>(F);
1041      };
1042
1043      const TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(OldF);
1044      Function *NewF =
1045          promoteArguments(&OldF, AARGetter, MaxElements, None, TTI);
1046      if (!NewF)
1047        continue;
1048      LocalChange = true;
1049
1050      // Directly substitute the functions in the call graph. Note that this
1051      // requires the old function to be completely dead and completely
1052      // replaced by the new function. It does no call graph updates, it merely
1053      // swaps out the particular function mapped to a particular node in the
1054      // graph.
1055      C.getOuterRefSCC().replaceNodeFunction(N, *NewF);
1056      OldF.eraseFromParent();
1057    }
1058
1059    Changed |= LocalChange;
1060  } while (LocalChange);
1061
1062  if (!Changed)
1063    return PreservedAnalyses::all();
1064
1065  return PreservedAnalyses::none();
1066}
1067
1068namespace {
1069
1070/// ArgPromotion - The 'by reference' to 'by value' argument promotion pass.
1071struct ArgPromotion : public CallGraphSCCPass {
1072  // Pass identification, replacement for typeid
1073  static char ID;
1074
1075  explicit ArgPromotion(unsigned MaxElements = 3)
1076      : CallGraphSCCPass(ID), MaxElements(MaxElements) {
1077    initializeArgPromotionPass(*PassRegistry::getPassRegistry());
1078  }
1079
1080  void getAnalysisUsage(AnalysisUsage &AU) const override {
1081    AU.addRequired<AssumptionCacheTracker>();
1082    AU.addRequired<TargetLibraryInfoWrapperPass>();
1083    AU.addRequired<TargetTransformInfoWrapperPass>();
1084    getAAResultsAnalysisUsage(AU);
1085    CallGraphSCCPass::getAnalysisUsage(AU);
1086  }
1087
1088  bool runOnSCC(CallGraphSCC &SCC) override;
1089
1090private:
1091  using llvm::Pass::doInitialization;
1092
1093  bool doInitialization(CallGraph &CG) override;
1094
1095  /// The maximum number of elements to expand, or 0 for unlimited.
1096  unsigned MaxElements;
1097};
1098
1099} // end anonymous namespace
1100
// Storage for the legacy pass identifier; the constructor hands it to
// CallGraphSCCPass(ID).
char ArgPromotion::ID = 0;

// Register the legacy pass under the name "argpromotion" and declare the
// legacy passes it depends on.
INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion",
                      "Promote 'by reference' arguments to scalars", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ArgPromotion, "argpromotion",
                    "Promote 'by reference' arguments to scalars", false, false)
1112
1113Pass *llvm::createArgumentPromotionPass(unsigned MaxElements) {
1114  return new ArgPromotion(MaxElements);
1115}
1116
1117bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
1118  if (skipSCC(SCC))
1119    return false;
1120
1121  // Get the callgraph information that we need to update to reflect our
1122  // changes.
1123  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
1124
1125  LegacyAARGetter AARGetter(*this);
1126
1127  bool Changed = false, LocalChange;
1128
1129  // Iterate until we stop promoting from this SCC.
1130  do {
1131    LocalChange = false;
1132    // Attempt to promote arguments from all functions in this SCC.
1133    for (CallGraphNode *OldNode : SCC) {
1134      Function *OldF = OldNode->getFunction();
1135      if (!OldF)
1136        continue;
1137
1138      auto ReplaceCallSite = [&](CallSite OldCS, CallSite NewCS) {
1139        Function *Caller = OldCS.getInstruction()->getParent()->getParent();
1140        CallGraphNode *NewCalleeNode =
1141            CG.getOrInsertFunction(NewCS.getCalledFunction());
1142        CallGraphNode *CallerNode = CG[Caller];
1143        CallerNode->replaceCallEdge(*cast<CallBase>(OldCS.getInstruction()),
1144                                    *cast<CallBase>(NewCS.getInstruction()),
1145                                    NewCalleeNode);
1146      };
1147
1148      const TargetTransformInfo &TTI =
1149          getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*OldF);
1150      if (Function *NewF = promoteArguments(OldF, AARGetter, MaxElements,
1151                                            {ReplaceCallSite}, TTI)) {
1152        LocalChange = true;
1153
1154        // Update the call graph for the newly promoted function.
1155        CallGraphNode *NewNode = CG.getOrInsertFunction(NewF);
1156        NewNode->stealCalledFunctionsFrom(OldNode);
1157        if (OldNode->getNumReferences() == 0)
1158          delete CG.removeFunctionFromModule(OldNode);
1159        else
1160          OldF->setLinkage(Function::ExternalLinkage);
1161
1162        // And updat ethe SCC we're iterating as well.
1163        SCC.ReplaceNode(OldNode, NewNode);
1164      }
1165    }
1166    // Remember that we changed something.
1167    Changed |= LocalChange;
1168  } while (LocalChange);
1169
1170  return Changed;
1171}
1172
1173bool ArgPromotion::doInitialization(CallGraph &CG) {
1174  return CallGraphSCCPass::doInitialization(CG);
1175}
1176