1//===- InferAddressSpace.cpp - --------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// CUDA C/C++ includes memory space designation as variable type qualifiers (such
10// as __global__ and __shared__). Knowing the space of a memory access allows
11// CUDA compilers to emit faster PTX loads and stores. For example, a load from
12// shared memory can be translated to `ld.shared` which is roughly 10% faster
13// than a generic `ld` on an NVIDIA Tesla K40c.
14//
15// Unfortunately, type qualifiers only apply to variable declarations, so CUDA
16// compilers must infer the memory space of an address expression from
17// type-qualified variables.
18//
19// LLVM IR uses non-zero (so-called) specific address spaces to represent memory
20// spaces (e.g. addrspace(3) means shared memory). The Clang frontend
21// places only type-qualified variables in specific address spaces, and then
22// conservatively `addrspacecast`s each type-qualified variable to addrspace(0)
23// (so-called the generic address space) for other instructions to use.
24//
// For example, Clang translates the following CUDA code
26//   __shared__ float a[10];
27//   float v = a[i];
28// to
29//   %0 = addrspacecast [10 x float] addrspace(3)* @a to [10 x float]*
30//   %1 = gep [10 x float], [10 x float]* %0, i64 0, i64 %i
31//   %v = load float, float* %1 ; emits ld.f32
32// @a is in addrspace(3) since it's type-qualified, but its use from %1 is
33// redirected to %0 (the generic version of @a).
34//
// The optimization implemented in this file propagates specific address spaces
// from type-qualified variable declarations to their users. For example, it
37// optimizes the above IR to
38//   %1 = gep [10 x float] addrspace(3)* @a, i64 0, i64 %i
39//   %v = load float addrspace(3)* %1 ; emits ld.shared.f32
40// propagating the addrspace(3) from @a to %1. As the result, the NVPTX
41// codegen is able to emit ld.shared.f32 for %v.
42//
43// Address space inference works in two steps. First, it uses a data-flow
44// analysis to infer as many generic pointers as possible to point to only one
45// specific address space. In the above example, it can prove that %1 only
46// points to addrspace(3). This algorithm was published in
47//   CUDA: Compiling and optimizing for a GPU platform
48//   Chakrabarti, Grover, Aarts, Kong, Kudlur, Lin, Marathe, Murphy, Wang
49//   ICCS 2012
50//
51// Then, address space inference replaces all refinable generic pointers with
52// equivalent specific pointers.
53//
54// The major challenge of implementing this optimization is handling PHINodes,
55// which may create loops in the data flow graph. This brings two complications.
56//
57// First, the data flow analysis in Step 1 needs to be circular. For example,
58//     %generic.input = addrspacecast float addrspace(3)* %input to float*
59//   loop:
60//     %y = phi [ %generic.input, %y2 ]
61//     %y2 = getelementptr %y, 1
62//     %v = load %y2
63//     br ..., label %loop, ...
64// proving %y specific requires proving both %generic.input and %y2 specific,
65// but proving %y2 specific circles back to %y. To address this complication,
66// the data flow analysis operates on a lattice:
67//   uninitialized > specific address spaces > generic.
68// All address expressions (our implementation only considers phi, bitcast,
69// addrspacecast, and getelementptr) start with the uninitialized address space.
70// The monotone transfer function moves the address space of a pointer down a
71// lattice path from uninitialized to specific and then to generic. A join
72// operation of two different specific address spaces pushes the expression down
73// to the generic address space. The analysis completes once it reaches a fixed
74// point.
75//
76// Second, IR rewriting in Step 2 also needs to be circular. For example,
77// converting %y to addrspace(3) requires the compiler to know the converted
78// %y2, but converting %y2 needs the converted %y. To address this complication,
79// we break these cycles using "poison" placeholders. When converting an
80// instruction `I` to a new address space, if its operand `Op` is not converted
81// yet, we let `I` temporarily use `poison` and fix all the uses later.
82// For instance, our algorithm first converts %y to
83//   %y' = phi float addrspace(3)* [ %input, poison ]
84// Then, it converts %y2 to
85//   %y2' = getelementptr %y', 1
86// Finally, it fixes the poison in %y' so that
87//   %y' = phi float addrspace(3)* [ %input, %y2' ]
88//
89//===----------------------------------------------------------------------===//
90
91#include "llvm/Transforms/Scalar/InferAddressSpaces.h"
92#include "llvm/ADT/ArrayRef.h"
93#include "llvm/ADT/DenseMap.h"
94#include "llvm/ADT/DenseSet.h"
95#include "llvm/ADT/SetVector.h"
96#include "llvm/ADT/SmallVector.h"
97#include "llvm/Analysis/AssumptionCache.h"
98#include "llvm/Analysis/TargetTransformInfo.h"
99#include "llvm/Analysis/ValueTracking.h"
100#include "llvm/IR/BasicBlock.h"
101#include "llvm/IR/Constant.h"
102#include "llvm/IR/Constants.h"
103#include "llvm/IR/Dominators.h"
104#include "llvm/IR/Function.h"
105#include "llvm/IR/IRBuilder.h"
106#include "llvm/IR/InstIterator.h"
107#include "llvm/IR/Instruction.h"
108#include "llvm/IR/Instructions.h"
109#include "llvm/IR/IntrinsicInst.h"
110#include "llvm/IR/Intrinsics.h"
111#include "llvm/IR/LLVMContext.h"
112#include "llvm/IR/Operator.h"
113#include "llvm/IR/PassManager.h"
114#include "llvm/IR/Type.h"
115#include "llvm/IR/Use.h"
116#include "llvm/IR/User.h"
117#include "llvm/IR/Value.h"
118#include "llvm/IR/ValueHandle.h"
119#include "llvm/InitializePasses.h"
120#include "llvm/Pass.h"
121#include "llvm/Support/Casting.h"
122#include "llvm/Support/CommandLine.h"
123#include "llvm/Support/Compiler.h"
124#include "llvm/Support/Debug.h"
125#include "llvm/Support/ErrorHandling.h"
126#include "llvm/Support/raw_ostream.h"
127#include "llvm/Transforms/Scalar.h"
128#include "llvm/Transforms/Utils/Local.h"
129#include "llvm/Transforms/Utils/ValueMapper.h"
130#include <cassert>
131#include <iterator>
132#include <limits>
133#include <utility>
134#include <vector>
135
136#define DEBUG_TYPE "infer-address-spaces"
137
138using namespace llvm;
139
// Command-line override that makes the pass assume the default address space
// is the flat address space; per its description, mainly for testing.
static cl::opt<bool> AssumeDefaultIsFlatAddressSpace(
    "assume-default-is-flat-addrspace", cl::init(false), cl::ReallyHidden,
    cl::desc("The default address space is assumed as the flat address space. "
             "This is mainly for test purpose."));

// Sentinel meaning "no address space inferred yet"; the top of the inference
// lattice described in the file header (uninitialized > specific > generic).
static const unsigned UninitializedAddressSpace =
    std::numeric_limits<unsigned>::max();
147
148namespace {
149
// Maps a value to the address space inferred for its *definition*.
using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>;
// Different from ValueToAddrSpaceMapTy, where a new addrspace is inferred on
// the *def* of a value, PredicatedAddrSpaceMapTy is a map where a new
// addrspace is inferred on the *use* of a pointer. This map is introduced to
// infer addrspace from the addrspace predicate assumption built from assume
// intrinsic. In that scenario, only specific uses (under valid assumption
// context) could be inferred with a new addrspace. The key is a
// (user, pointer operand) pair.
using PredicatedAddrSpaceMapTy =
    DenseMap<std::pair<const Value *, const Value *>, unsigned>;
// DFS stack for the non-recursive postorder traversal; the int bit records
// whether the value's operands have already been pushed for exploration.
using PostorderStackTy = llvm::SmallVector<PointerIntPair<Value *, 1, bool>, 4>;
160
/// Legacy pass manager wrapper; the actual work is done by
/// InferAddressSpacesImpl in runOnFunction.
class InferAddressSpaces : public FunctionPass {
  // The target's flat (generic) address space. UninitializedAddressSpace when
  // not supplied at construction; presumably resolved later from the target —
  // see runOnFunction (defined elsewhere).
  unsigned FlatAddrSpace = 0;

public:
  static char ID;

  InferAddressSpaces()
      : FunctionPass(ID), FlatAddrSpace(UninitializedAddressSpace) {
    initializeInferAddressSpacesPass(*PassRegistry::getPassRegistry());
  }
  // Construct with an explicitly-known flat address space \p AS.
  InferAddressSpaces(unsigned AS) : FunctionPass(ID), FlatAddrSpace(AS) {
    initializeInferAddressSpacesPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // The pass rewrites values only; it never changes the CFG, so the
    // dominator tree survives it.
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnFunction(Function &F) override;
};
184
/// Pass implementation, shared between the legacy wrapper above and the new
/// pass manager entry point.
class InferAddressSpacesImpl {
  AssumptionCache &AC;
  const DominatorTree *DT = nullptr;
  const TargetTransformInfo *TTI = nullptr;
  const DataLayout *DL = nullptr;

  /// Target specific address space which uses of should be replaced if
  /// possible.
  unsigned FlatAddrSpace = 0;

  // Tries to update the address space of V. Returns true if V's address space
  // was updated, and false otherwise.
  bool updateAddressSpace(const Value &V,
                          ValueToAddrSpaceMapTy &InferredAddrSpace,
                          PredicatedAddrSpaceMapTy &PredicatedAS) const;

  // Tries to infer the specific address space of each address expression in
  // Postorder (the data-flow fixed-point step described in the file header).
  void inferAddressSpaces(ArrayRef<WeakTrackingVH> Postorder,
                          ValueToAddrSpaceMapTy &InferredAddrSpace,
                          PredicatedAddrSpaceMapTy &PredicatedAS) const;

  // Whether the constant C may legally be reinterpreted as pointing into
  // address space NewAS (definition not in this chunk).
  bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const;

  // Clones instruction I with its pointer operands remapped to NewAddrSpace;
  // operand uses whose clones are not ready are set to poison and recorded in
  // PoisonUsesToFix. See the definition below for details.
  Value *cloneInstructionWithNewAddressSpace(
      Instruction *I, unsigned NewAddrSpace,
      const ValueToValueMapTy &ValueWithNewAddrSpace,
      const PredicatedAddrSpaceMapTy &PredicatedAS,
      SmallVectorImpl<const Use *> *PoisonUsesToFix) const;

  // Changes the flat address expressions in function F to point to specific
  // address spaces if InferredAddrSpace says so. Postorder is the postorder of
  // all flat expressions in the use-def graph of function F.
  bool
  rewriteWithNewAddressSpaces(ArrayRef<WeakTrackingVH> Postorder,
                              const ValueToAddrSpaceMapTy &InferredAddrSpace,
                              const PredicatedAddrSpaceMapTy &PredicatedAS,
                              Function *F) const;

  // If V is an unvisited flat address expression, appends it to
  // PostorderStack and marks it visited.
  void appendsFlatAddressExpressionToPostorderStack(
      Value *V, PostorderStackTy &PostorderStack,
      DenseSet<Value *> &Visited) const;

  // Rewrites the pointer operand of intrinsic call II from OldV to NewV;
  // returns true if the call was rewritten.
  bool rewriteIntrinsicOperands(IntrinsicInst *II, Value *OldV,
                                Value *NewV) const;
  // Pushes II's rewritable flat-pointer operands onto PostorderStack.
  void collectRewritableIntrinsicOperands(IntrinsicInst *II,
                                          PostorderStackTy &PostorderStack,
                                          DenseSet<Value *> &Visited) const;

  // Returns all flat address expressions in F, ordered in postorder.
  std::vector<WeakTrackingVH> collectFlatAddressExpressions(Function &F) const;

  // Dispatches cloning of V (instruction or constant expression) into the new
  // address space (definition not in this chunk).
  Value *cloneValueWithNewAddressSpace(
      Value *V, unsigned NewAddrSpace,
      const ValueToValueMapTy &ValueWithNewAddrSpace,
      const PredicatedAddrSpaceMapTy &PredicatedAS,
      SmallVectorImpl<const Use *> *PoisonUsesToFix) const;
  // Lattice join of two address spaces; per the file header, joining two
  // different specific spaces yields the flat (generic) space.
  unsigned joinAddressSpaces(unsigned AS1, unsigned AS2) const;

  // Address space predicated on the use of Opnd in V, derived from assume
  // intrinsics (definition not in this chunk).
  unsigned getPredicatedAddrSpace(const Value &V, Value *Opnd) const;

public:
  InferAddressSpacesImpl(AssumptionCache &AC, const DominatorTree *DT,
                         const TargetTransformInfo *TTI, unsigned FlatAddrSpace)
      : AC(AC), DT(DT), TTI(TTI), FlatAddrSpace(FlatAddrSpace) {}
  // Entry point: runs the inference and rewriting over F.
  bool run(Function &F);
};
251
252} // end anonymous namespace
253
// Pass identification token for the legacy pass manager.
char InferAddressSpaces::ID = 0;

INITIALIZE_PASS_BEGIN(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces",
                    false, false)
262
263static Type *getPtrOrVecOfPtrsWithNewAS(Type *Ty, unsigned NewAddrSpace) {
264  assert(Ty->isPtrOrPtrVectorTy());
265  PointerType *NPT = PointerType::get(Ty->getContext(), NewAddrSpace);
266  return Ty->getWithNewType(NPT);
267}
268
269// Check whether that's no-op pointer bicast using a pair of
270// `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over
271// different address spaces.
272static bool isNoopPtrIntCastPair(const Operator *I2P, const DataLayout &DL,
273                                 const TargetTransformInfo *TTI) {
274  assert(I2P->getOpcode() == Instruction::IntToPtr);
275  auto *P2I = dyn_cast<Operator>(I2P->getOperand(0));
276  if (!P2I || P2I->getOpcode() != Instruction::PtrToInt)
277    return false;
278  // Check it's really safe to treat that pair of `ptrtoint`/`inttoptr` as a
279  // no-op cast. Besides checking both of them are no-op casts, as the
280  // reinterpreted pointer may be used in other pointer arithmetic, we also
281  // need to double-check that through the target-specific hook. That ensures
282  // the underlying target also agrees that's a no-op address space cast and
283  // pointer bits are preserved.
284  // The current IR spec doesn't have clear rules on address space casts,
285  // especially a clear definition for pointer bits in non-default address
286  // spaces. It would be undefined if that pointer is dereferenced after an
287  // invalid reinterpret cast. Also, due to the unclearness for the meaning of
288  // bits in non-default address spaces in the current spec, the pointer
289  // arithmetic may also be undefined after invalid pointer reinterpret cast.
290  // However, as we confirm through the target hooks that it's a no-op
291  // addrspacecast, it doesn't matter since the bits should be the same.
292  unsigned P2IOp0AS = P2I->getOperand(0)->getType()->getPointerAddressSpace();
293  unsigned I2PAS = I2P->getType()->getPointerAddressSpace();
294  return CastInst::isNoopCast(Instruction::CastOps(I2P->getOpcode()),
295                              I2P->getOperand(0)->getType(), I2P->getType(),
296                              DL) &&
297         CastInst::isNoopCast(Instruction::CastOps(P2I->getOpcode()),
298                              P2I->getOperand(0)->getType(), P2I->getType(),
299                              DL) &&
300         (P2IOp0AS == I2PAS || TTI->isNoopAddrSpaceCast(P2IOp0AS, I2PAS));
301}
302
303// Returns true if V is an address expression.
304// TODO: Currently, we consider only phi, bitcast, addrspacecast, and
305// getelementptr operators.
306static bool isAddressExpression(const Value &V, const DataLayout &DL,
307                                const TargetTransformInfo *TTI) {
308  const Operator *Op = dyn_cast<Operator>(&V);
309  if (!Op)
310    return false;
311
312  switch (Op->getOpcode()) {
313  case Instruction::PHI:
314    assert(Op->getType()->isPtrOrPtrVectorTy());
315    return true;
316  case Instruction::BitCast:
317  case Instruction::AddrSpaceCast:
318  case Instruction::GetElementPtr:
319    return true;
320  case Instruction::Select:
321    return Op->getType()->isPtrOrPtrVectorTy();
322  case Instruction::Call: {
323    const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V);
324    return II && II->getIntrinsicID() == Intrinsic::ptrmask;
325  }
326  case Instruction::IntToPtr:
327    return isNoopPtrIntCastPair(Op, DL, TTI);
328  default:
329    // That value is an address expression if it has an assumed address space.
330    return TTI->getAssumedAddrSpace(&V) != UninitializedAddressSpace;
331  }
332}
333
334// Returns the pointer operands of V.
335//
336// Precondition: V is an address expression.
337static SmallVector<Value *, 2>
338getPointerOperands(const Value &V, const DataLayout &DL,
339                   const TargetTransformInfo *TTI) {
340  const Operator &Op = cast<Operator>(V);
341  switch (Op.getOpcode()) {
342  case Instruction::PHI: {
343    auto IncomingValues = cast<PHINode>(Op).incoming_values();
344    return {IncomingValues.begin(), IncomingValues.end()};
345  }
346  case Instruction::BitCast:
347  case Instruction::AddrSpaceCast:
348  case Instruction::GetElementPtr:
349    return {Op.getOperand(0)};
350  case Instruction::Select:
351    return {Op.getOperand(1), Op.getOperand(2)};
352  case Instruction::Call: {
353    const IntrinsicInst &II = cast<IntrinsicInst>(Op);
354    assert(II.getIntrinsicID() == Intrinsic::ptrmask &&
355           "unexpected intrinsic call");
356    return {II.getArgOperand(0)};
357  }
358  case Instruction::IntToPtr: {
359    assert(isNoopPtrIntCastPair(&Op, DL, TTI));
360    auto *P2I = cast<Operator>(Op.getOperand(0));
361    return {P2I->getOperand(0)};
362  }
363  default:
364    llvm_unreachable("Unexpected instruction type.");
365  }
366}
367
368bool InferAddressSpacesImpl::rewriteIntrinsicOperands(IntrinsicInst *II,
369                                                      Value *OldV,
370                                                      Value *NewV) const {
371  Module *M = II->getParent()->getParent()->getParent();
372
373  switch (II->getIntrinsicID()) {
374  case Intrinsic::objectsize: {
375    Type *DestTy = II->getType();
376    Type *SrcTy = NewV->getType();
377    Function *NewDecl =
378        Intrinsic::getDeclaration(M, II->getIntrinsicID(), {DestTy, SrcTy});
379    II->setArgOperand(0, NewV);
380    II->setCalledFunction(NewDecl);
381    return true;
382  }
383  case Intrinsic::ptrmask:
384    // This is handled as an address expression, not as a use memory operation.
385    return false;
386  case Intrinsic::masked_gather: {
387    Type *RetTy = II->getType();
388    Type *NewPtrTy = NewV->getType();
389    Function *NewDecl =
390        Intrinsic::getDeclaration(M, II->getIntrinsicID(), {RetTy, NewPtrTy});
391    II->setArgOperand(0, NewV);
392    II->setCalledFunction(NewDecl);
393    return true;
394  }
395  case Intrinsic::masked_scatter: {
396    Type *ValueTy = II->getOperand(0)->getType();
397    Type *NewPtrTy = NewV->getType();
398    Function *NewDecl =
399        Intrinsic::getDeclaration(M, II->getIntrinsicID(), {ValueTy, NewPtrTy});
400    II->setArgOperand(1, NewV);
401    II->setCalledFunction(NewDecl);
402    return true;
403  }
404  default: {
405    Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
406    if (!Rewrite)
407      return false;
408    if (Rewrite != II)
409      II->replaceAllUsesWith(Rewrite);
410    return true;
411  }
412  }
413}
414
415void InferAddressSpacesImpl::collectRewritableIntrinsicOperands(
416    IntrinsicInst *II, PostorderStackTy &PostorderStack,
417    DenseSet<Value *> &Visited) const {
418  auto IID = II->getIntrinsicID();
419  switch (IID) {
420  case Intrinsic::ptrmask:
421  case Intrinsic::objectsize:
422    appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
423                                                 PostorderStack, Visited);
424    break;
425  case Intrinsic::masked_gather:
426    appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
427                                                 PostorderStack, Visited);
428    break;
429  case Intrinsic::masked_scatter:
430    appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(1),
431                                                 PostorderStack, Visited);
432    break;
433  default:
434    SmallVector<int, 2> OpIndexes;
435    if (TTI->collectFlatAddressOperands(OpIndexes, IID)) {
436      for (int Idx : OpIndexes) {
437        appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(Idx),
438                                                     PostorderStack, Visited);
439      }
440    }
441    break;
442  }
443}
444
445// Returns all flat address expressions in function F. The elements are
446// If V is an unvisited flat address expression, appends V to PostorderStack
447// and marks it as visited.
448void InferAddressSpacesImpl::appendsFlatAddressExpressionToPostorderStack(
449    Value *V, PostorderStackTy &PostorderStack,
450    DenseSet<Value *> &Visited) const {
451  assert(V->getType()->isPtrOrPtrVectorTy());
452
453  // Generic addressing expressions may be hidden in nested constant
454  // expressions.
455  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
456    // TODO: Look in non-address parts, like icmp operands.
457    if (isAddressExpression(*CE, *DL, TTI) && Visited.insert(CE).second)
458      PostorderStack.emplace_back(CE, false);
459
460    return;
461  }
462
463  if (V->getType()->getPointerAddressSpace() == FlatAddrSpace &&
464      isAddressExpression(*V, *DL, TTI)) {
465    if (Visited.insert(V).second) {
466      PostorderStack.emplace_back(V, false);
467
468      Operator *Op = cast<Operator>(V);
469      for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) {
470        if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op->getOperand(I))) {
471          if (isAddressExpression(*CE, *DL, TTI) && Visited.insert(CE).second)
472            PostorderStack.emplace_back(CE, false);
473        }
474      }
475    }
476  }
477}
478
// Returns all flat address expressions in function F. The elements are ordered
// in postorder.
std::vector<WeakTrackingVH>
InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const {
  // This function implements a non-recursive postorder traversal of a partial
  // use-def graph of function F.
  PostorderStackTy PostorderStack;
  // The set of visited expressions.
  DenseSet<Value *> Visited;

  auto PushPtrOperand = [&](Value *Ptr) {
    appendsFlatAddressExpressionToPostorderStack(Ptr, PostorderStack, Visited);
  };

  // Look at operations that may be interesting to accelerate by moving to a
  // known address space. We mainly aim at loads and stores, but pure
  // addressing calculations may also be faster.
  for (Instruction &I : instructions(F)) {
    if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
      PushPtrOperand(GEP->getPointerOperand());
    } else if (auto *LI = dyn_cast<LoadInst>(&I))
      PushPtrOperand(LI->getPointerOperand());
    else if (auto *SI = dyn_cast<StoreInst>(&I))
      PushPtrOperand(SI->getPointerOperand());
    else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I))
      PushPtrOperand(RMW->getPointerOperand());
    else if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(&I))
      PushPtrOperand(CmpX->getPointerOperand());
    else if (auto *MI = dyn_cast<MemIntrinsic>(&I)) {
      // For memset/memcpy/memmove, any pointer operand can be replaced.
      PushPtrOperand(MI->getRawDest());

      // Handle 2nd operand for memcpy/memmove.
      if (auto *MTI = dyn_cast<MemTransferInst>(MI))
        PushPtrOperand(MTI->getRawSource());
    } else if (auto *II = dyn_cast<IntrinsicInst>(&I))
      collectRewritableIntrinsicOperands(II, PostorderStack, Visited);
    else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) {
      // Pointer-typed comparison operands are also candidates.
      if (Cmp->getOperand(0)->getType()->isPtrOrPtrVectorTy()) {
        PushPtrOperand(Cmp->getOperand(0));
        PushPtrOperand(Cmp->getOperand(1));
      }
    } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) {
      PushPtrOperand(ASC->getPointerOperand());
    } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) {
      // Only a no-op ptrtoint/inttoptr pair is worth looking through; push
      // the original pointer behind the pair.
      if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI))
        PushPtrOperand(cast<Operator>(I2P->getOperand(0))->getOperand(0));
    } else if (auto *RI = dyn_cast<ReturnInst>(&I)) {
      if (auto *RV = RI->getReturnValue();
          RV && RV->getType()->isPtrOrPtrVectorTy())
        PushPtrOperand(RV);
    }
  }

  std::vector<WeakTrackingVH> Postorder; // The resultant postorder.
  while (!PostorderStack.empty()) {
    Value *TopVal = PostorderStack.back().getPointer();
    // If the operands of the expression on the top are already explored,
    // adds that expression to the resultant postorder.
    if (PostorderStack.back().getInt()) {
      if (TopVal->getType()->getPointerAddressSpace() == FlatAddrSpace)
        Postorder.push_back(TopVal);
      PostorderStack.pop_back();
      continue;
    }
    // Otherwise, adds its operands to the stack and explores them.
    PostorderStack.back().setInt(true);
    // Skip values with an assumed address space.
    if (TTI->getAssumedAddrSpace(TopVal) == UninitializedAddressSpace) {
      for (Value *PtrOperand : getPointerOperands(*TopVal, *DL, TTI)) {
        appendsFlatAddressExpressionToPostorderStack(PtrOperand, PostorderStack,
                                                     Visited);
      }
    }
  }
  return Postorder;
}
556
557// A helper function for cloneInstructionWithNewAddressSpace. Returns the clone
558// of OperandUse.get() in the new address space. If the clone is not ready yet,
559// returns poison in the new address space as a placeholder.
560static Value *operandWithNewAddressSpaceOrCreatePoison(
561    const Use &OperandUse, unsigned NewAddrSpace,
562    const ValueToValueMapTy &ValueWithNewAddrSpace,
563    const PredicatedAddrSpaceMapTy &PredicatedAS,
564    SmallVectorImpl<const Use *> *PoisonUsesToFix) {
565  Value *Operand = OperandUse.get();
566
567  Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAddrSpace);
568
569  if (Constant *C = dyn_cast<Constant>(Operand))
570    return ConstantExpr::getAddrSpaceCast(C, NewPtrTy);
571
572  if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand))
573    return NewOperand;
574
575  Instruction *Inst = cast<Instruction>(OperandUse.getUser());
576  auto I = PredicatedAS.find(std::make_pair(Inst, Operand));
577  if (I != PredicatedAS.end()) {
578    // Insert an addrspacecast on that operand before the user.
579    unsigned NewAS = I->second;
580    Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAS);
581    auto *NewI = new AddrSpaceCastInst(Operand, NewPtrTy);
582    NewI->insertBefore(Inst);
583    NewI->setDebugLoc(Inst->getDebugLoc());
584    return NewI;
585  }
586
587  PoisonUsesToFix->push_back(&OperandUse);
588  return PoisonValue::get(NewPtrTy);
589}
590
// Returns a clone of `I` with its operands converted to those specified in
// ValueWithNewAddrSpace. Due to potential cycles in the data flow graph, an
// operand whose address space needs to be modified might not exist in
// ValueWithNewAddrSpace. In that case, uses poison as a placeholder operand and
// adds that operand use to PoisonUsesToFix so that caller can fix them later.
//
// Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast
// from a pointer whose type already matches. Therefore, this function returns a
// Value* instead of an Instruction*.
//
// This may also return nullptr in the case the instruction could not be
// rewritten.
Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace(
    Instruction *I, unsigned NewAddrSpace,
    const ValueToValueMapTy &ValueWithNewAddrSpace,
    const PredicatedAddrSpaceMapTy &PredicatedAS,
    SmallVectorImpl<const Use *> *PoisonUsesToFix) const {
  Type *NewPtrType = getPtrOrVecOfPtrsWithNewAS(I->getType(), NewAddrSpace);

  if (I->getOpcode() == Instruction::AddrSpaceCast) {
    Value *Src = I->getOperand(0);
    // Because `I` is flat, the source address space must be specific.
    // Therefore, the inferred address space must be the source space, according
    // to our algorithm.
    assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
    if (Src->getType() != NewPtrType)
      return new BitCastInst(Src, NewPtrType);
    // The cast becomes a no-op; reuse the source directly.
    return Src;
  }

  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
    // Technically the intrinsic ID is a pointer typed argument, so specially
    // handle calls early.
    assert(II->getIntrinsicID() == Intrinsic::ptrmask);
    Value *NewPtr = operandWithNewAddressSpaceOrCreatePoison(
        II->getArgOperandUse(0), NewAddrSpace, ValueWithNewAddrSpace,
        PredicatedAS, PoisonUsesToFix);
    Value *Rewrite =
        TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr);
    if (Rewrite) {
      assert(Rewrite != II && "cannot modify this pointer operation in place");
      return Rewrite;
    }

    // The target declined to rewrite; signal failure to the caller.
    return nullptr;
  }

  unsigned AS = TTI->getAssumedAddrSpace(I);
  if (AS != UninitializedAddressSpace) {
    // For the assumed address space, insert an `addrspacecast` to make that
    // explicit.
    Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(I->getType(), AS);
    auto *NewI = new AddrSpaceCastInst(I, NewPtrTy);
    NewI->insertAfter(I);
    return NewI;
  }

  // Computes the converted pointer operands. Slots for non-pointer operands
  // are left as nullptr; the per-opcode cloning below only reads the slots
  // that hold pointers.
  SmallVector<Value *, 4> NewPointerOperands;
  for (const Use &OperandUse : I->operands()) {
    if (!OperandUse.get()->getType()->isPtrOrPtrVectorTy())
      NewPointerOperands.push_back(nullptr);
    else
      NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreatePoison(
          OperandUse, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS,
          PoisonUsesToFix));
  }

  switch (I->getOpcode()) {
  case Instruction::BitCast:
    return new BitCastInst(NewPointerOperands[0], NewPtrType);
  case Instruction::PHI: {
    assert(I->getType()->isPtrOrPtrVectorTy());
    PHINode *PHI = cast<PHINode>(I);
    PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues());
    for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) {
      // Map each incoming-value index to its operand slot in
      // NewPointerOperands.
      unsigned OperandNo = PHINode::getOperandNumForIncomingValue(Index);
      NewPHI->addIncoming(NewPointerOperands[OperandNo],
                          PHI->getIncomingBlock(Index));
    }
    return NewPHI;
  }
  case Instruction::GetElementPtr: {
    GetElementPtrInst *GEP = cast<GetElementPtrInst>(I);
    GetElementPtrInst *NewGEP = GetElementPtrInst::Create(
        GEP->getSourceElementType(), NewPointerOperands[0],
        SmallVector<Value *, 4>(GEP->indices()));
    NewGEP->setIsInBounds(GEP->isInBounds());
    return NewGEP;
  }
  case Instruction::Select:
    assert(I->getType()->isPtrOrPtrVectorTy());
    // The condition (operand 0) is not a pointer; reuse it as-is.
    return SelectInst::Create(I->getOperand(0), NewPointerOperands[1],
                              NewPointerOperands[2], "", nullptr, I);
  case Instruction::IntToPtr: {
    assert(isNoopPtrIntCastPair(cast<Operator>(I), *DL, TTI));
    Value *Src = cast<Operator>(I->getOperand(0))->getOperand(0);
    if (Src->getType() == NewPtrType)
      return Src;

    // If we had a no-op inttoptr/ptrtoint pair, we may still have inferred a
    // source address space from a generic pointer source and need to insert a
    // cast back.
    return CastInst::CreatePointerBitCastOrAddrSpaceCast(Src, NewPtrType);
  }
  default:
    llvm_unreachable("Unexpected opcode");
  }
}
700
// Similar to cloneInstructionWithNewAddressSpace, returns a clone of the
// constant expression `CE` with its operands replaced as specified in
// ValueWithNewAddrSpace. Returns nullptr if no operand was changed, meaning
// the caller should keep using `CE` unmodified.
static Value *cloneConstantExprWithNewAddressSpace(
    ConstantExpr *CE, unsigned NewAddrSpace,
    const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL,
    const TargetTransformInfo *TTI) {
  // The clone keeps CE's type except that a pointer (or vector-of-pointer)
  // result is retargeted to NewAddrSpace.
  Type *TargetType =
      CE->getType()->isPtrOrPtrVectorTy()
          ? getPtrOrVecOfPtrsWithNewAS(CE->getType(), NewAddrSpace)
          : CE->getType();

  if (CE->getOpcode() == Instruction::AddrSpaceCast) {
    // Because CE is flat, the source address space must be specific.
    // Therefore, the inferred address space must be the source space according
    // to our algorithm.
    assert(CE->getOperand(0)->getType()->getPointerAddressSpace() ==
           NewAddrSpace);
    return ConstantExpr::getBitCast(CE->getOperand(0), TargetType);
  }

  if (CE->getOpcode() == Instruction::BitCast) {
    // Prefer the already-cloned operand if one exists; otherwise fall back to
    // an addrspacecast of the whole expression.
    if (Value *NewOperand = ValueWithNewAddrSpace.lookup(CE->getOperand(0)))
      return ConstantExpr::getBitCast(cast<Constant>(NewOperand), TargetType);
    return ConstantExpr::getAddrSpaceCast(CE, TargetType);
  }

  if (CE->getOpcode() == Instruction::IntToPtr) {
    // Only no-op inttoptr(ptrtoint p) pairs reach here; strip both casts and
    // reuse the inner pointer, which must already be in the target space.
    assert(isNoopPtrIntCastPair(cast<Operator>(CE), *DL, TTI));
    Constant *Src = cast<ConstantExpr>(CE->getOperand(0))->getOperand(0);
    assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
    return ConstantExpr::getBitCast(Src, TargetType);
  }

  // Computes the operands of the new constant expression.
  bool IsNew = false;
  SmallVector<Constant *, 4> NewOperands;
  for (unsigned Index = 0; Index < CE->getNumOperands(); ++Index) {
    Constant *Operand = CE->getOperand(Index);
    // If the address space of `Operand` needs to be modified, the new operand
    // with the new address space should already be in ValueWithNewAddrSpace
    // because (1) the constant expressions we consider (i.e. addrspacecast,
    // bitcast, and getelementptr) do not incur cycles in the data flow graph
    // and (2) this function is called on constant expressions in postorder.
    if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand)) {
      IsNew = true;
      NewOperands.push_back(cast<Constant>(NewOperand));
      continue;
    }
    // Nested constant expressions may still need cloning even though they were
    // not visited as top-level postorder entries.
    if (auto *CExpr = dyn_cast<ConstantExpr>(Operand))
      if (Value *NewOperand = cloneConstantExprWithNewAddressSpace(
              CExpr, NewAddrSpace, ValueWithNewAddrSpace, DL, TTI)) {
        IsNew = true;
        NewOperands.push_back(cast<Constant>(NewOperand));
        continue;
      }
    // Otherwise, reuses the old operand.
    NewOperands.push_back(Operand);
  }

  // If !IsNew, we will replace the Value with itself. However, replaced values
  // are assumed to be wrapped in an addrspacecast cast later so drop it now.
  if (!IsNew)
    return nullptr;

  if (CE->getOpcode() == Instruction::GetElementPtr) {
    // Needs to specify the source type while constructing a getelementptr
    // constant expression.
    return CE->getWithOperands(NewOperands, TargetType, /*OnlyIfReduced=*/false,
                               cast<GEPOperator>(CE)->getSourceElementType());
  }

  return CE->getWithOperands(NewOperands, TargetType);
}
775
776// Returns a clone of the value `V`, with its operands replaced as specified in
777// ValueWithNewAddrSpace. This function is called on every flat address
778// expression whose address space needs to be modified, in postorder.
779//
780// See cloneInstructionWithNewAddressSpace for the meaning of PoisonUsesToFix.
781Value *InferAddressSpacesImpl::cloneValueWithNewAddressSpace(
782    Value *V, unsigned NewAddrSpace,
783    const ValueToValueMapTy &ValueWithNewAddrSpace,
784    const PredicatedAddrSpaceMapTy &PredicatedAS,
785    SmallVectorImpl<const Use *> *PoisonUsesToFix) const {
786  // All values in Postorder are flat address expressions.
787  assert(V->getType()->getPointerAddressSpace() == FlatAddrSpace &&
788         isAddressExpression(*V, *DL, TTI));
789
790  if (Instruction *I = dyn_cast<Instruction>(V)) {
791    Value *NewV = cloneInstructionWithNewAddressSpace(
792        I, NewAddrSpace, ValueWithNewAddrSpace, PredicatedAS, PoisonUsesToFix);
793    if (Instruction *NewI = dyn_cast_or_null<Instruction>(NewV)) {
794      if (NewI->getParent() == nullptr) {
795        NewI->insertBefore(I);
796        NewI->takeName(I);
797        NewI->setDebugLoc(I->getDebugLoc());
798      }
799    }
800    return NewV;
801  }
802
803  return cloneConstantExprWithNewAddressSpace(
804      cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace, DL, TTI);
805}
806
807// Defines the join operation on the address space lattice (see the file header
808// comments).
809unsigned InferAddressSpacesImpl::joinAddressSpaces(unsigned AS1,
810                                                   unsigned AS2) const {
811  if (AS1 == FlatAddrSpace || AS2 == FlatAddrSpace)
812    return FlatAddrSpace;
813
814  if (AS1 == UninitializedAddressSpace)
815    return AS2;
816  if (AS2 == UninitializedAddressSpace)
817    return AS1;
818
819  // The join of two different specific address spaces is flat.
820  return (AS1 == AS2) ? AS1 : FlatAddrSpace;
821}
822
823bool InferAddressSpacesImpl::run(Function &F) {
824  DL = &F.getParent()->getDataLayout();
825
826  if (AssumeDefaultIsFlatAddressSpace)
827    FlatAddrSpace = 0;
828
829  if (FlatAddrSpace == UninitializedAddressSpace) {
830    FlatAddrSpace = TTI->getFlatAddressSpace();
831    if (FlatAddrSpace == UninitializedAddressSpace)
832      return false;
833  }
834
835  // Collects all flat address expressions in postorder.
836  std::vector<WeakTrackingVH> Postorder = collectFlatAddressExpressions(F);
837
838  // Runs a data-flow analysis to refine the address spaces of every expression
839  // in Postorder.
840  ValueToAddrSpaceMapTy InferredAddrSpace;
841  PredicatedAddrSpaceMapTy PredicatedAS;
842  inferAddressSpaces(Postorder, InferredAddrSpace, PredicatedAS);
843
844  // Changes the address spaces of the flat address expressions who are inferred
845  // to point to a specific address space.
846  return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace, PredicatedAS,
847                                     &F);
848}
849
// Constants need to be tracked through RAUW to handle cases with nested
// constant expressions, so wrap values in WeakTrackingVH.
//
// Runs a worklist-based fixed-point iteration: each expression's address
// space is repeatedly re-joined from its operands until nothing changes.
// Results are recorded in InferredAddrSpace; assumes-derived refinements are
// recorded in PredicatedAS.
void InferAddressSpacesImpl::inferAddressSpaces(
    ArrayRef<WeakTrackingVH> Postorder,
    ValueToAddrSpaceMapTy &InferredAddrSpace,
    PredicatedAddrSpaceMapTy &PredicatedAS) const {
  // SetVector gives cheap membership tests for the "already queued" check
  // below while preserving a deterministic processing order.
  SetVector<Value *> Worklist(Postorder.begin(), Postorder.end());
  // Initially, all expressions are in the uninitialized address space.
  for (Value *V : Postorder)
    InferredAddrSpace[V] = UninitializedAddressSpace;

  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();

    // Try to update the address space of the stack top according to the
    // address spaces of its operands.
    if (!updateAddressSpace(*V, InferredAddrSpace, PredicatedAS))
      continue;

    // V's inferred space changed, so its users may need to be revisited.
    for (Value *User : V->users()) {
      // Skip if User is already in the worklist.
      if (Worklist.count(User))
        continue;

      auto Pos = InferredAddrSpace.find(User);
      // Our algorithm only updates the address spaces of flat address
      // expressions, which are those in InferredAddrSpace.
      if (Pos == InferredAddrSpace.end())
        continue;

      // Function updateAddressSpace moves the address space down a lattice
      // path. Therefore, nothing to do if User is already inferred as flat (the
      // bottom element in the lattice).
      if (Pos->second == FlatAddrSpace)
        continue;

      Worklist.insert(User);
    }
  }
}
890
891unsigned InferAddressSpacesImpl::getPredicatedAddrSpace(const Value &V,
892                                                        Value *Opnd) const {
893  const Instruction *I = dyn_cast<Instruction>(&V);
894  if (!I)
895    return UninitializedAddressSpace;
896
897  Opnd = Opnd->stripInBoundsOffsets();
898  for (auto &AssumeVH : AC.assumptionsFor(Opnd)) {
899    if (!AssumeVH)
900      continue;
901    CallInst *CI = cast<CallInst>(AssumeVH);
902    if (!isValidAssumeForContext(CI, I, DT))
903      continue;
904
905    const Value *Ptr;
906    unsigned AS;
907    std::tie(Ptr, AS) = TTI->getPredicatedAddrSpace(CI->getArgOperand(0));
908    if (Ptr)
909      return AS;
910  }
911
912  return UninitializedAddressSpace;
913}
914
// Recomputes the inferred address space of V as the join of the spaces of its
// pointer operands (with special handling for selects and target-assumed
// spaces). Returns true iff the recorded value changed, so the caller knows
// to revisit V's users.
bool InferAddressSpacesImpl::updateAddressSpace(
    const Value &V, ValueToAddrSpaceMapTy &InferredAddrSpace,
    PredicatedAddrSpaceMapTy &PredicatedAS) const {
  assert(InferredAddrSpace.count(&V));

  LLVM_DEBUG(dbgs() << "Updating the address space of\n  " << V << '\n');

  // The new inferred address space equals the join of the address spaces
  // of all its pointer operands.
  unsigned NewAS = UninitializedAddressSpace;

  const Operator &Op = cast<Operator>(V);
  if (Op.getOpcode() == Instruction::Select) {
    Value *Src0 = Op.getOperand(1);
    Value *Src1 = Op.getOperand(2);

    // Operands not in the map are not flat expressions; use their declared
    // pointer address space directly.
    auto I = InferredAddrSpace.find(Src0);
    unsigned Src0AS = (I != InferredAddrSpace.end())
                          ? I->second
                          : Src0->getType()->getPointerAddressSpace();

    auto J = InferredAddrSpace.find(Src1);
    unsigned Src1AS = (J != InferredAddrSpace.end())
                          ? J->second
                          : Src1->getType()->getPointerAddressSpace();

    auto *C0 = dyn_cast<Constant>(Src0);
    auto *C1 = dyn_cast<Constant>(Src1);

    // If one of the inputs is a constant, we may be able to do a constant
    // addrspacecast of it. Defer inferring the address space until the input
    // address space is known.
    if ((C1 && Src0AS == UninitializedAddressSpace) ||
        (C0 && Src1AS == UninitializedAddressSpace))
      return false;

    // Prefer adopting the non-constant side's space when the constant can be
    // legally cast into it; otherwise fall back to the plain lattice join.
    if (C0 && isSafeToCastConstAddrSpace(C0, Src1AS))
      NewAS = Src1AS;
    else if (C1 && isSafeToCastConstAddrSpace(C1, Src0AS))
      NewAS = Src0AS;
    else
      NewAS = joinAddressSpaces(Src0AS, Src1AS);
  } else {
    unsigned AS = TTI->getAssumedAddrSpace(&V);
    if (AS != UninitializedAddressSpace) {
      // Use the assumed address space directly.
      NewAS = AS;
    } else {
      // Otherwise, infer the address space from its pointer operands.
      for (Value *PtrOperand : getPointerOperands(V, *DL, TTI)) {
        auto I = InferredAddrSpace.find(PtrOperand);
        unsigned OperandAS;
        if (I == InferredAddrSpace.end()) {
          OperandAS = PtrOperand->getType()->getPointerAddressSpace();
          if (OperandAS == FlatAddrSpace) {
            // Check AC for assumption dominating V.
            unsigned AS = getPredicatedAddrSpace(V, PtrOperand);
            if (AS != UninitializedAddressSpace) {
              LLVM_DEBUG(dbgs()
                         << "  deduce operand AS from the predicate addrspace "
                         << AS << '\n');
              OperandAS = AS;
              // Record this use with the predicated AS.
              PredicatedAS[std::make_pair(&V, PtrOperand)] = OperandAS;
            }
          }
        } else
          OperandAS = I->second;

        // join(flat, *) = flat. So we can break if NewAS is already flat.
        NewAS = joinAddressSpaces(NewAS, OperandAS);
        if (NewAS == FlatAddrSpace)
          break;
      }
    }
  }

  unsigned OldAS = InferredAddrSpace.lookup(&V);
  // Once flat, a value can never be revisited (see inferAddressSpaces), so
  // OldAS being flat here would indicate a broken worklist invariant.
  assert(OldAS != FlatAddrSpace);
  if (OldAS == NewAS)
    return false;

  // If any updates are made, grabs its users to the worklist because
  // their address spaces can also be possibly updated.
  LLVM_DEBUG(dbgs() << "  to " << NewAS << '\n');
  InferredAddrSpace[&V] = NewAS;
  return true;
}
1003
1004/// \p returns true if \p U is the pointer operand of a memory instruction with
1005/// a single pointer operand that can have its address space changed by simply
1006/// mutating the use to a new value. If the memory instruction is volatile,
1007/// return true only if the target allows the memory instruction to be volatile
1008/// in the new address space.
1009static bool isSimplePointerUseValidToReplace(const TargetTransformInfo &TTI,
1010                                             Use &U, unsigned AddrSpace) {
1011  User *Inst = U.getUser();
1012  unsigned OpNo = U.getOperandNo();
1013  bool VolatileIsAllowed = false;
1014  if (auto *I = dyn_cast<Instruction>(Inst))
1015    VolatileIsAllowed = TTI.hasVolatileVariant(I, AddrSpace);
1016
1017  if (auto *LI = dyn_cast<LoadInst>(Inst))
1018    return OpNo == LoadInst::getPointerOperandIndex() &&
1019           (VolatileIsAllowed || !LI->isVolatile());
1020
1021  if (auto *SI = dyn_cast<StoreInst>(Inst))
1022    return OpNo == StoreInst::getPointerOperandIndex() &&
1023           (VolatileIsAllowed || !SI->isVolatile());
1024
1025  if (auto *RMW = dyn_cast<AtomicRMWInst>(Inst))
1026    return OpNo == AtomicRMWInst::getPointerOperandIndex() &&
1027           (VolatileIsAllowed || !RMW->isVolatile());
1028
1029  if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst))
1030    return OpNo == AtomicCmpXchgInst::getPointerOperandIndex() &&
1031           (VolatileIsAllowed || !CmpX->isVolatile());
1032
1033  return false;
1034}
1035
/// Update memory intrinsic uses that require more complex processing than
/// simple memory instructions. These require re-mangling and may have multiple
/// pointer operands.
///
/// Rebuilds the intrinsic via IRBuilder with OldV's occurrences replaced by
/// NewV, carries over TBAA/alias-scope/noalias metadata, and erases the old
/// intrinsic. Always returns true. Callers must ensure MI is non-volatile
/// (the rebuilt intrinsics are created non-volatile).
static bool handleMemIntrinsicPtrUse(MemIntrinsic *MI, Value *OldV,
                                     Value *NewV) {
  IRBuilder<> B(MI);
  // Preserve aliasing metadata on the rebuilt intrinsic.
  MDNode *TBAA = MI->getMetadata(LLVMContext::MD_tbaa);
  MDNode *ScopeMD = MI->getMetadata(LLVMContext::MD_alias_scope);
  MDNode *NoAliasMD = MI->getMetadata(LLVMContext::MD_noalias);

  if (auto *MSI = dyn_cast<MemSetInst>(MI)) {
    // memset has a single pointer operand: the destination.
    B.CreateMemSet(NewV, MSI->getValue(), MSI->getLength(), MSI->getDestAlign(),
                   false, // isVolatile
                   TBAA, ScopeMD, NoAliasMD);
  } else if (auto *MTI = dyn_cast<MemTransferInst>(MI)) {
    Value *Src = MTI->getRawSource();
    Value *Dest = MTI->getRawDest();

    // Be careful in case this is a self-to-self copy.
    if (Src == OldV)
      Src = NewV;

    if (Dest == OldV)
      Dest = NewV;

    // memcpy.inline must be checked before memcpy since it is a subclass.
    if (isa<MemCpyInlineInst>(MTI)) {
      MDNode *TBAAStruct = MTI->getMetadata(LLVMContext::MD_tbaa_struct);
      B.CreateMemCpyInline(Dest, MTI->getDestAlign(), Src,
                           MTI->getSourceAlign(), MTI->getLength(),
                           false, // isVolatile
                           TBAA, TBAAStruct, ScopeMD, NoAliasMD);
    } else if (isa<MemCpyInst>(MTI)) {
      MDNode *TBAAStruct = MTI->getMetadata(LLVMContext::MD_tbaa_struct);
      B.CreateMemCpy(Dest, MTI->getDestAlign(), Src, MTI->getSourceAlign(),
                     MTI->getLength(),
                     false, // isVolatile
                     TBAA, TBAAStruct, ScopeMD, NoAliasMD);
    } else {
      assert(isa<MemMoveInst>(MTI));
      B.CreateMemMove(Dest, MTI->getDestAlign(), Src, MTI->getSourceAlign(),
                      MTI->getLength(),
                      false, // isVolatile
                      TBAA, ScopeMD, NoAliasMD);
    }
  } else
    llvm_unreachable("unhandled MemIntrinsic");

  // The replacement intrinsic has been emitted; drop the original.
  MI->eraseFromParent();
  return true;
}
1086
1087// \p returns true if it is OK to change the address space of constant \p C with
1088// a ConstantExpr addrspacecast.
1089bool InferAddressSpacesImpl::isSafeToCastConstAddrSpace(Constant *C,
1090                                                        unsigned NewAS) const {
1091  assert(NewAS != UninitializedAddressSpace);
1092
1093  unsigned SrcAS = C->getType()->getPointerAddressSpace();
1094  if (SrcAS == NewAS || isa<UndefValue>(C))
1095    return true;
1096
1097  // Prevent illegal casts between different non-flat address spaces.
1098  if (SrcAS != FlatAddrSpace && NewAS != FlatAddrSpace)
1099    return false;
1100
1101  if (isa<ConstantPointerNull>(C))
1102    return true;
1103
1104  if (auto *Op = dyn_cast<Operator>(C)) {
1105    // If we already have a constant addrspacecast, it should be safe to cast it
1106    // off.
1107    if (Op->getOpcode() == Instruction::AddrSpaceCast)
1108      return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)),
1109                                        NewAS);
1110
1111    if (Op->getOpcode() == Instruction::IntToPtr &&
1112        Op->getType()->getPointerAddressSpace() == FlatAddrSpace)
1113      return true;
1114  }
1115
1116  return false;
1117}
1118
1119static Value::use_iterator skipToNextUser(Value::use_iterator I,
1120                                          Value::use_iterator End) {
1121  User *CurUser = I->getUser();
1122  ++I;
1123
1124  while (I != End && I->getUser() == CurUser)
1125    ++I;
1126
1127  return I;
1128}
1129
// Rewrites every flat address expression in Postorder whose inferred address
// space is specific: first clones each such expression into the new space,
// then redirects uses of the originals to the clones (inserting addrspacecasts
// where a flat pointer is still required), and finally deletes dead originals.
// Returns true iff the function was changed.
bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
    ArrayRef<WeakTrackingVH> Postorder,
    const ValueToAddrSpaceMapTy &InferredAddrSpace,
    const PredicatedAddrSpaceMapTy &PredicatedAS, Function *F) const {
  // For each address expression to be modified, creates a clone of it with its
  // pointer operands converted to the new address space. Since the pointer
  // operands are converted, the clone is naturally in the new address space by
  // construction.
  ValueToValueMapTy ValueWithNewAddrSpace;
  SmallVector<const Use *, 32> PoisonUsesToFix;
  for (Value *V : Postorder) {
    unsigned NewAddrSpace = InferredAddrSpace.lookup(V);

    // In some degenerate cases (e.g. invalid IR in unreachable code), we may
    // not even infer the value to have its original address space.
    if (NewAddrSpace == UninitializedAddressSpace)
      continue;

    if (V->getType()->getPointerAddressSpace() != NewAddrSpace) {
      Value *New =
          cloneValueWithNewAddressSpace(V, NewAddrSpace, ValueWithNewAddrSpace,
                                        PredicatedAS, &PoisonUsesToFix);
      if (New)
        ValueWithNewAddrSpace[V] = New;
    }
  }

  if (ValueWithNewAddrSpace.empty())
    return false;

  // Fixes all the poison uses generated by cloneInstructionWithNewAddressSpace.
  // These are operands that could not be resolved at clone time (e.g. PHI
  // inputs cloned before their sources) and were temporarily set to poison.
  for (const Use *PoisonUse : PoisonUsesToFix) {
    User *V = PoisonUse->getUser();
    User *NewV = cast_or_null<User>(ValueWithNewAddrSpace.lookup(V));
    if (!NewV)
      continue;

    unsigned OperandNo = PoisonUse->getOperandNo();
    assert(isa<PoisonValue>(NewV->getOperand(OperandNo)));
    NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(PoisonUse->get()));
  }

  SmallVector<Instruction *, 16> DeadInstructions;
  ValueToValueMapTy VMap;
  ValueMapper VMapper(VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);

  // Replaces the uses of the old address expressions with the new ones.
  for (const WeakTrackingVH &WVH : Postorder) {
    assert(WVH && "value was unexpectedly deleted");
    Value *V = WVH;
    Value *NewV = ValueWithNewAddrSpace.lookup(V);
    if (NewV == nullptr)
      continue;

    LLVM_DEBUG(dbgs() << "Replacing the uses of " << *V << "\n  with\n  "
                      << *NewV << '\n');

    // Constants may be used outside this function (e.g. in other functions or
    // in constant expressions), so they need special handling: rewrite only
    // the uses reachable from instructions of F, via an addrspacecast back to
    // the original type.
    if (Constant *C = dyn_cast<Constant>(V)) {
      Constant *Replace =
          ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), C->getType());
      if (C != Replace) {
        LLVM_DEBUG(dbgs() << "Inserting replacement const cast: " << Replace
                          << ": " << *Replace << '\n');
        SmallVector<User *, 16> WorkList;
        for (User *U : make_early_inc_range(C->users())) {
          if (auto *I = dyn_cast<Instruction>(U)) {
            if (I->getFunction() == F)
              I->replaceUsesOfWith(C, Replace);
          } else {
            // Constant-expression users need a remap walk to reach the
            // instructions that ultimately use them.
            WorkList.append(U->user_begin(), U->user_end());
          }
        }
        if (!WorkList.empty()) {
          VMap[C] = Replace;
          DenseSet<User *> Visited{WorkList.begin(), WorkList.end()};
          while (!WorkList.empty()) {
            User *U = WorkList.pop_back_val();
            if (auto *I = dyn_cast<Instruction>(U)) {
              if (I->getFunction() == F)
                VMapper.remapInstruction(*I);
              continue;
            }
            for (User *U2 : U->users())
              if (Visited.insert(U2).second)
                WorkList.push_back(U2);
          }
        }
        V = Replace;
      }
    }

    Value::use_iterator I, E, Next;
    for (I = V->use_begin(), E = V->use_end(); I != E;) {
      Use &U = *I;

      // Some users may see the same pointer operand in multiple operands. Skip
      // to the next instruction.
      I = skipToNextUser(I, E);

      if (isSimplePointerUseValidToReplace(
              *TTI, U, V->getType()->getPointerAddressSpace())) {
        // If V is used as the pointer operand of a compatible memory operation,
        // sets the pointer operand to NewV. This replacement does not change
        // the element type, so the resultant load/store is still valid.
        U.set(NewV);
        continue;
      }

      User *CurUser = U.getUser();
      // Skip if the current user is the new value itself.
      if (CurUser == NewV)
        continue;

      // Uses in other functions (through constants) are left alone.
      if (auto *CurUserI = dyn_cast<Instruction>(CurUser);
          CurUserI && CurUserI->getFunction() != F)
        continue;

      // Handle more complex cases like intrinsic that need to be remangled.
      if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) {
        if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV))
          continue;
      }

      if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
        if (rewriteIntrinsicOperands(II, V, NewV))
          continue;
      }

      if (isa<Instruction>(CurUser)) {
        if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUser)) {
          // If we can infer that both pointers are in the same addrspace,
          // transform e.g.
          //   %cmp = icmp eq float* %p, %q
          // into
          //   %cmp = icmp eq float addrspace(3)* %new_p, %new_q

          unsigned NewAS = NewV->getType()->getPointerAddressSpace();
          int SrcIdx = U.getOperandNo();
          int OtherIdx = (SrcIdx == 0) ? 1 : 0;
          Value *OtherSrc = Cmp->getOperand(OtherIdx);

          if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
            if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
              Cmp->setOperand(OtherIdx, OtherNewV);
              Cmp->setOperand(SrcIdx, NewV);
              continue;
            }
          }

          // Even if the type mismatches, we can cast the constant.
          if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) {
            if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) {
              Cmp->setOperand(SrcIdx, NewV);
              Cmp->setOperand(OtherIdx, ConstantExpr::getAddrSpaceCast(
                                            KOtherSrc, NewV->getType()));
              continue;
            }
          }
        }

        // An addrspacecast whose destination is the new space becomes a no-op:
        // forward its uses to NewV and mark it for deletion.
        if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) {
          unsigned NewAS = NewV->getType()->getPointerAddressSpace();
          if (ASC->getDestAddressSpace() == NewAS) {
            ASC->replaceAllUsesWith(NewV);
            DeadInstructions.push_back(ASC);
            continue;
          }
        }

        // Otherwise, replaces the use with flat(NewV).
        if (Instruction *VInst = dyn_cast<Instruction>(V)) {
          // Don't create a copy of the original addrspacecast.
          if (U == V && isa<AddrSpaceCastInst>(V))
            continue;

          // Insert the addrspacecast after NewV.
          BasicBlock::iterator InsertPos;
          if (Instruction *NewVInst = dyn_cast<Instruction>(NewV))
            InsertPos = std::next(NewVInst->getIterator());
          else
            InsertPos = std::next(VInst->getIterator());

          // PHIs must stay at the top of the block; insert after them.
          while (isa<PHINode>(InsertPos))
            ++InsertPos;
          U.set(new AddrSpaceCastInst(NewV, V->getType(), "", &*InsertPos));
        } else {
          U.set(ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV),
                                               V->getType()));
        }
      }
    }

    // The original expression may now be unused; queue it for cleanup.
    if (V->use_empty()) {
      if (Instruction *I = dyn_cast<Instruction>(V))
        DeadInstructions.push_back(I);
    }
  }

  for (Instruction *I : DeadInstructions)
    RecursivelyDeleteTriviallyDeadInstructions(I);

  return true;
}
1333
1334bool InferAddressSpaces::runOnFunction(Function &F) {
1335  if (skipFunction(F))
1336    return false;
1337
1338  auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
1339  DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr;
1340  return InferAddressSpacesImpl(
1341             getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), DT,
1342             &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
1343             FlatAddrSpace)
1344      .run(F);
1345}
1346
// Factory for the legacy pass-manager wrapper. Ownership of the returned pass
// transfers to the caller (normally the legacy PassManager).
FunctionPass *llvm::createInferAddressSpacesPass(unsigned AddressSpace) {
  return new InferAddressSpaces(AddressSpace);
}
1350
// Default: flat address space unknown; it is filled in from
// TargetTransformInfo when the pass runs.
InferAddressSpacesPass::InferAddressSpacesPass()
    : FlatAddrSpace(UninitializedAddressSpace) {}
// Explicitly pin the flat address space (overrides the target's default).
InferAddressSpacesPass::InferAddressSpacesPass(unsigned AddressSpace)
    : FlatAddrSpace(AddressSpace) {}
1355
1356PreservedAnalyses InferAddressSpacesPass::run(Function &F,
1357                                              FunctionAnalysisManager &AM) {
1358  bool Changed =
1359      InferAddressSpacesImpl(AM.getResult<AssumptionAnalysis>(F),
1360                             AM.getCachedResult<DominatorTreeAnalysis>(F),
1361                             &AM.getResult<TargetIRAnalysis>(F), FlatAddrSpace)
1362          .run(F);
1363  if (Changed) {
1364    PreservedAnalyses PA;
1365    PA.preserveSet<CFGAnalyses>();
1366    PA.preserve<DominatorTreeAnalysis>();
1367    return PA;
1368  }
1369  return PreservedAnalyses::all();
1370}
1371