NVPTXAsmPrinter.cpp revision 360784
1//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a printer that converts from our internal representation
10// of machine-dependent LLVM code to NVPTX assembly language.
11//
12//===----------------------------------------------------------------------===//
13
14#include "NVPTXAsmPrinter.h"
15#include "MCTargetDesc/NVPTXBaseInfo.h"
16#include "MCTargetDesc/NVPTXInstPrinter.h"
17#include "MCTargetDesc/NVPTXMCAsmInfo.h"
18#include "MCTargetDesc/NVPTXTargetStreamer.h"
19#include "NVPTX.h"
20#include "NVPTXMCExpr.h"
21#include "NVPTXMachineFunctionInfo.h"
22#include "NVPTXRegisterInfo.h"
23#include "NVPTXSubtarget.h"
24#include "NVPTXTargetMachine.h"
25#include "NVPTXUtilities.h"
26#include "TargetInfo/NVPTXTargetInfo.h"
27#include "cl_common_defines.h"
28#include "llvm/ADT/APFloat.h"
29#include "llvm/ADT/APInt.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/DenseSet.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/SmallVector.h"
34#include "llvm/ADT/StringExtras.h"
35#include "llvm/ADT/StringRef.h"
36#include "llvm/ADT/Triple.h"
37#include "llvm/ADT/Twine.h"
38#include "llvm/Analysis/ConstantFolding.h"
39#include "llvm/CodeGen/Analysis.h"
40#include "llvm/CodeGen/MachineBasicBlock.h"
41#include "llvm/CodeGen/MachineFrameInfo.h"
42#include "llvm/CodeGen/MachineFunction.h"
43#include "llvm/CodeGen/MachineInstr.h"
44#include "llvm/CodeGen/MachineLoopInfo.h"
45#include "llvm/CodeGen/MachineModuleInfo.h"
46#include "llvm/CodeGen/MachineOperand.h"
47#include "llvm/CodeGen/MachineRegisterInfo.h"
48#include "llvm/CodeGen/TargetLowering.h"
49#include "llvm/CodeGen/TargetRegisterInfo.h"
50#include "llvm/CodeGen/ValueTypes.h"
51#include "llvm/IR/Attributes.h"
52#include "llvm/IR/BasicBlock.h"
53#include "llvm/IR/Constant.h"
54#include "llvm/IR/Constants.h"
55#include "llvm/IR/DataLayout.h"
56#include "llvm/IR/DebugInfo.h"
57#include "llvm/IR/DebugInfoMetadata.h"
58#include "llvm/IR/DebugLoc.h"
59#include "llvm/IR/DerivedTypes.h"
60#include "llvm/IR/Function.h"
61#include "llvm/IR/GlobalValue.h"
62#include "llvm/IR/GlobalVariable.h"
63#include "llvm/IR/Instruction.h"
64#include "llvm/IR/LLVMContext.h"
65#include "llvm/IR/Module.h"
66#include "llvm/IR/Operator.h"
67#include "llvm/IR/Type.h"
68#include "llvm/IR/User.h"
69#include "llvm/MC/MCExpr.h"
70#include "llvm/MC/MCInst.h"
71#include "llvm/MC/MCInstrDesc.h"
72#include "llvm/MC/MCStreamer.h"
73#include "llvm/MC/MCSymbol.h"
74#include "llvm/Support/Casting.h"
75#include "llvm/Support/CommandLine.h"
76#include "llvm/Support/ErrorHandling.h"
77#include "llvm/Support/MachineValueType.h"
78#include "llvm/Support/Path.h"
79#include "llvm/Support/TargetRegistry.h"
80#include "llvm/Support/raw_ostream.h"
81#include "llvm/Target/TargetLoweringObjectFile.h"
82#include "llvm/Target/TargetMachine.h"
83#include "llvm/Transforms/Utils/UnrollLoop.h"
84#include <cassert>
85#include <cstdint>
86#include <cstring>
87#include <new>
88#include <string>
89#include <utility>
90#include <vector>
91
92using namespace llvm;
93
94#define DEPOTNAME "__local_depot"
95
96/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
97/// depends.
98static void
99DiscoverDependentGlobals(const Value *V,
100                         DenseSet<const GlobalVariable *> &Globals) {
101  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
102    Globals.insert(GV);
103  else {
104    if (const User *U = dyn_cast<User>(V)) {
105      for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
106        DiscoverDependentGlobals(U->getOperand(i), Globals);
107      }
108    }
109  }
110}
111
112/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
113/// instances to be emitted, but only after any dependents have been added
114/// first.s
115static void
116VisitGlobalVariableForEmission(const GlobalVariable *GV,
117                               SmallVectorImpl<const GlobalVariable *> &Order,
118                               DenseSet<const GlobalVariable *> &Visited,
119                               DenseSet<const GlobalVariable *> &Visiting) {
120  // Have we already visited this one?
121  if (Visited.count(GV))
122    return;
123
124  // Do we have a circular dependency?
125  if (!Visiting.insert(GV).second)
126    report_fatal_error("Circular dependency found in global variable set");
127
128  // Make sure we visit all dependents first
129  DenseSet<const GlobalVariable *> Others;
130  for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
131    DiscoverDependentGlobals(GV->getOperand(i), Others);
132
133  for (DenseSet<const GlobalVariable *>::iterator I = Others.begin(),
134                                                  E = Others.end();
135       I != E; ++I)
136    VisitGlobalVariableForEmission(*I, Order, Visited, Visiting);
137
138  // Now we can visit ourself
139  Order.push_back(GV);
140  Visited.insert(GV);
141  Visiting.erase(GV);
142}
143
144void NVPTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
145  MCInst Inst;
146  lowerToMCInst(MI, Inst);
147  EmitToStreamer(*OutStreamer, Inst);
148}
149
150// Handle symbol backtracking for targets that do not support image handles
151bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
152                                           unsigned OpNo, MCOperand &MCOp) {
153  const MachineOperand &MO = MI->getOperand(OpNo);
154  const MCInstrDesc &MCID = MI->getDesc();
155
156  if (MCID.TSFlags & NVPTXII::IsTexFlag) {
157    // This is a texture fetch, so operand 4 is a texref and operand 5 is
158    // a samplerref
159    if (OpNo == 4 && MO.isImm()) {
160      lowerImageHandleSymbol(MO.getImm(), MCOp);
161      return true;
162    }
163    if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
164      lowerImageHandleSymbol(MO.getImm(), MCOp);
165      return true;
166    }
167
168    return false;
169  } else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
170    unsigned VecSize =
171      1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
172
173    // For a surface load of vector size N, the Nth operand will be the surfref
174    if (OpNo == VecSize && MO.isImm()) {
175      lowerImageHandleSymbol(MO.getImm(), MCOp);
176      return true;
177    }
178
179    return false;
180  } else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
181    // This is a surface store, so operand 0 is a surfref
182    if (OpNo == 0 && MO.isImm()) {
183      lowerImageHandleSymbol(MO.getImm(), MCOp);
184      return true;
185    }
186
187    return false;
188  } else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
189    // This is a query, so operand 1 is a surfref/texref
190    if (OpNo == 1 && MO.isImm()) {
191      lowerImageHandleSymbol(MO.getImm(), MCOp);
192      return true;
193    }
194
195    return false;
196  }
197
198  return false;
199}
200
201void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
202  // Ewwww
203  LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
204  NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
205  const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
206  const char *Sym = MFI->getImageHandleSymbol(Index);
207  std::string *SymNamePtr =
208    nvTM.getManagedStrPool()->getManagedString(Sym);
209  MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(StringRef(*SymNamePtr)));
210}
211
212void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
213  OutMI.setOpcode(MI->getOpcode());
214  // Special: Do not mangle symbol operand of CALL_PROTOTYPE
215  if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
216    const MachineOperand &MO = MI->getOperand(0);
217    OutMI.addOperand(GetSymbolRef(
218      OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
219    return;
220  }
221
222  const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
223  for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
224    const MachineOperand &MO = MI->getOperand(i);
225
226    MCOperand MCOp;
227    if (!STI.hasImageHandles()) {
228      if (lowerImageHandleOperand(MI, i, MCOp)) {
229        OutMI.addOperand(MCOp);
230        continue;
231      }
232    }
233
234    if (lowerOperand(MO, MCOp))
235      OutMI.addOperand(MCOp);
236  }
237}
238
239bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
240                                   MCOperand &MCOp) {
241  switch (MO.getType()) {
242  default: llvm_unreachable("unknown operand type");
243  case MachineOperand::MO_Register:
244    MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
245    break;
246  case MachineOperand::MO_Immediate:
247    MCOp = MCOperand::createImm(MO.getImm());
248    break;
249  case MachineOperand::MO_MachineBasicBlock:
250    MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
251        MO.getMBB()->getSymbol(), OutContext));
252    break;
253  case MachineOperand::MO_ExternalSymbol:
254    MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
255    break;
256  case MachineOperand::MO_GlobalAddress:
257    MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
258    break;
259  case MachineOperand::MO_FPImmediate: {
260    const ConstantFP *Cnt = MO.getFPImm();
261    const APFloat &Val = Cnt->getValueAPF();
262
263    switch (Cnt->getType()->getTypeID()) {
264    default: report_fatal_error("Unsupported FP type"); break;
265    case Type::HalfTyID:
266      MCOp = MCOperand::createExpr(
267        NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
268      break;
269    case Type::FloatTyID:
270      MCOp = MCOperand::createExpr(
271        NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
272      break;
273    case Type::DoubleTyID:
274      MCOp = MCOperand::createExpr(
275        NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
276      break;
277    }
278    break;
279  }
280  }
281  return true;
282}
283
284unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
285  if (Register::isVirtualRegister(Reg)) {
286    const TargetRegisterClass *RC = MRI->getRegClass(Reg);
287
288    DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
289    unsigned RegNum = RegMap[Reg];
290
291    // Encode the register class in the upper 4 bits
292    // Must be kept in sync with NVPTXInstPrinter::printRegName
293    unsigned Ret = 0;
294    if (RC == &NVPTX::Int1RegsRegClass) {
295      Ret = (1 << 28);
296    } else if (RC == &NVPTX::Int16RegsRegClass) {
297      Ret = (2 << 28);
298    } else if (RC == &NVPTX::Int32RegsRegClass) {
299      Ret = (3 << 28);
300    } else if (RC == &NVPTX::Int64RegsRegClass) {
301      Ret = (4 << 28);
302    } else if (RC == &NVPTX::Float32RegsRegClass) {
303      Ret = (5 << 28);
304    } else if (RC == &NVPTX::Float64RegsRegClass) {
305      Ret = (6 << 28);
306    } else if (RC == &NVPTX::Float16RegsRegClass) {
307      Ret = (7 << 28);
308    } else if (RC == &NVPTX::Float16x2RegsRegClass) {
309      Ret = (8 << 28);
310    } else {
311      report_fatal_error("Bad register class");
312    }
313
314    // Insert the vreg number
315    Ret |= (RegNum & 0x0FFFFFFF);
316    return Ret;
317  } else {
318    // Some special-use registers are actually physical registers.
319    // Encode this as the register class ID of 0 and the real register ID.
320    return Reg & 0x0FFFFFFF;
321  }
322}
323
324MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
325  const MCExpr *Expr;
326  Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,
327                                 OutContext);
328  return MCOperand::createExpr(Expr);
329}
330
331void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
332  const DataLayout &DL = getDataLayout();
333  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
334  const TargetLowering *TLI = STI.getTargetLowering();
335
336  Type *Ty = F->getReturnType();
337
338  bool isABI = (STI.getSmVersion() >= 20);
339
340  if (Ty->getTypeID() == Type::VoidTyID)
341    return;
342
343  O << " (";
344
345  if (isABI) {
346    if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
347      unsigned size = 0;
348      if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
349        size = ITy->getBitWidth();
350      } else {
351        assert(Ty->isFloatingPointTy() && "Floating point type expected here");
352        size = Ty->getPrimitiveSizeInBits();
353      }
354      // PTX ABI requires all scalar return values to be at least 32
355      // bits in size.  fp16 normally uses .b16 as its storage type in
356      // PTX, so its size must be adjusted here, too.
357      if (size < 32)
358        size = 32;
359
360      O << ".param .b" << size << " func_retval0";
361    } else if (isa<PointerType>(Ty)) {
362      O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
363        << " func_retval0";
364    } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
365      unsigned totalsz = DL.getTypeAllocSize(Ty);
366      unsigned retAlignment = 0;
367      if (!getAlign(*F, 0, retAlignment))
368        retAlignment = DL.getABITypeAlignment(Ty);
369      O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz
370        << "]";
371    } else
372      llvm_unreachable("Unknown return type");
373  } else {
374    SmallVector<EVT, 16> vtparts;
375    ComputeValueVTs(*TLI, DL, Ty, vtparts);
376    unsigned idx = 0;
377    for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
378      unsigned elems = 1;
379      EVT elemtype = vtparts[i];
380      if (vtparts[i].isVector()) {
381        elems = vtparts[i].getVectorNumElements();
382        elemtype = vtparts[i].getVectorElementType();
383      }
384
385      for (unsigned j = 0, je = elems; j != je; ++j) {
386        unsigned sz = elemtype.getSizeInBits();
387        if (elemtype.isInteger() && (sz < 32))
388          sz = 32;
389        O << ".reg .b" << sz << " func_retval" << idx;
390        if (j < je - 1)
391          O << ", ";
392        ++idx;
393      }
394      if (i < e - 1)
395        O << ", ";
396    }
397  }
398  O << ") ";
399}
400
401void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
402                                        raw_ostream &O) {
403  const Function &F = MF.getFunction();
404  printReturnValStr(&F, O);
405}
406
407// Return true if MBB is the header of a loop marked with
408// llvm.loop.unroll.disable.
409// TODO: consider "#pragma unroll 1" which is equivalent to "#pragma nounroll".
410bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
411    const MachineBasicBlock &MBB) const {
412  MachineLoopInfo &LI = getAnalysis<MachineLoopInfo>();
413  // We insert .pragma "nounroll" only to the loop header.
414  if (!LI.isLoopHeader(&MBB))
415    return false;
416
417  // llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
418  // we iterate through each back edge of the loop with header MBB, and check
419  // whether its metadata contains llvm.loop.unroll.disable.
420  for (auto I = MBB.pred_begin(); I != MBB.pred_end(); ++I) {
421    const MachineBasicBlock *PMBB = *I;
422    if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
423      // Edges from other loops to MBB are not back edges.
424      continue;
425    }
426    if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
427      if (MDNode *LoopID =
428              PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
429        if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
430          return true;
431      }
432    }
433  }
434  return false;
435}
436
437void NVPTXAsmPrinter::EmitBasicBlockStart(const MachineBasicBlock &MBB) {
438  AsmPrinter::EmitBasicBlockStart(MBB);
439  if (isLoopHeaderOfNoUnroll(MBB))
440    OutStreamer->EmitRawText(StringRef("\t.pragma \"nounroll\";\n"));
441}
442
443void NVPTXAsmPrinter::EmitFunctionEntryLabel() {
444  SmallString<128> Str;
445  raw_svector_ostream O(Str);
446
447  if (!GlobalsEmitted) {
448    emitGlobals(*MF->getFunction().getParent());
449    GlobalsEmitted = true;
450  }
451
452  // Set up
453  MRI = &MF->getRegInfo();
454  F = &MF->getFunction();
455  emitLinkageDirective(F, O);
456  if (isKernelFunction(*F))
457    O << ".entry ";
458  else {
459    O << ".func ";
460    printReturnValStr(*MF, O);
461  }
462
463  CurrentFnSym->print(O, MAI);
464
465  emitFunctionParamList(*MF, O);
466
467  if (isKernelFunction(*F))
468    emitKernelFunctionDirectives(*F, O);
469
470  OutStreamer->EmitRawText(O.str());
471
472  VRegMapping.clear();
473  // Emit open brace for function body.
474  OutStreamer->EmitRawText(StringRef("{\n"));
475  setAndEmitFunctionVirtualRegisters(*MF);
476  // Emit initial .loc debug directive for correct relocation symbol data.
477  if (MMI && MMI->hasDebugInfo())
478    emitInitialRawDwarfLocDirective(*MF);
479}
480
481bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
482  bool Result = AsmPrinter::runOnMachineFunction(F);
483  // Emit closing brace for the body of function F.
484  // The closing brace must be emitted here because we need to emit additional
485  // debug labels/data after the last basic block.
486  // We need to emit the closing brace here because we don't have function that
487  // finished emission of the function body.
488  OutStreamer->EmitRawText(StringRef("}\n"));
489  return Result;
490}
491
492void NVPTXAsmPrinter::EmitFunctionBodyStart() {
493  SmallString<128> Str;
494  raw_svector_ostream O(Str);
495  emitDemotedVars(&MF->getFunction(), O);
496  OutStreamer->EmitRawText(O.str());
497}
498
499void NVPTXAsmPrinter::EmitFunctionBodyEnd() {
500  VRegMapping.clear();
501}
502
503const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
504    SmallString<128> Str;
505    raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();
506    return OutContext.getOrCreateSymbol(Str);
507}
508
509void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
510  Register RegNo = MI->getOperand(0).getReg();
511  if (Register::isVirtualRegister(RegNo)) {
512    OutStreamer->AddComment(Twine("implicit-def: ") +
513                            getVirtualRegisterName(RegNo));
514  } else {
515    const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
516    OutStreamer->AddComment(Twine("implicit-def: ") +
517                            STI.getRegisterInfo()->getName(RegNo));
518  }
519  OutStreamer->AddBlankLine();
520}
521
522void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
523                                                   raw_ostream &O) const {
524  // If the NVVM IR has some of reqntid* specified, then output
525  // the reqntid directive, and set the unspecified ones to 1.
526  // If none of reqntid* is specified, don't output reqntid directive.
527  unsigned reqntidx, reqntidy, reqntidz;
528  bool specified = false;
529  if (!getReqNTIDx(F, reqntidx))
530    reqntidx = 1;
531  else
532    specified = true;
533  if (!getReqNTIDy(F, reqntidy))
534    reqntidy = 1;
535  else
536    specified = true;
537  if (!getReqNTIDz(F, reqntidz))
538    reqntidz = 1;
539  else
540    specified = true;
541
542  if (specified)
543    O << ".reqntid " << reqntidx << ", " << reqntidy << ", " << reqntidz
544      << "\n";
545
546  // If the NVVM IR has some of maxntid* specified, then output
547  // the maxntid directive, and set the unspecified ones to 1.
548  // If none of maxntid* is specified, don't output maxntid directive.
549  unsigned maxntidx, maxntidy, maxntidz;
550  specified = false;
551  if (!getMaxNTIDx(F, maxntidx))
552    maxntidx = 1;
553  else
554    specified = true;
555  if (!getMaxNTIDy(F, maxntidy))
556    maxntidy = 1;
557  else
558    specified = true;
559  if (!getMaxNTIDz(F, maxntidz))
560    maxntidz = 1;
561  else
562    specified = true;
563
564  if (specified)
565    O << ".maxntid " << maxntidx << ", " << maxntidy << ", " << maxntidz
566      << "\n";
567
568  unsigned mincta;
569  if (getMinCTASm(F, mincta))
570    O << ".minnctapersm " << mincta << "\n";
571
572  unsigned maxnreg;
573  if (getMaxNReg(F, maxnreg))
574    O << ".maxnreg " << maxnreg << "\n";
575}
576
577std::string
578NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
579  const TargetRegisterClass *RC = MRI->getRegClass(Reg);
580
581  std::string Name;
582  raw_string_ostream NameStr(Name);
583
584  VRegRCMap::const_iterator I = VRegMapping.find(RC);
585  assert(I != VRegMapping.end() && "Bad register class");
586  const DenseMap<unsigned, unsigned> &RegMap = I->second;
587
588  VRegMap::const_iterator VI = RegMap.find(Reg);
589  assert(VI != RegMap.end() && "Bad virtual register");
590  unsigned MappedVR = VI->second;
591
592  NameStr << getNVPTXRegClassStr(RC) << MappedVR;
593
594  NameStr.flush();
595  return Name;
596}
597
598void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
599                                          raw_ostream &O) {
600  O << getVirtualRegisterName(vr);
601}
602
603void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
604  emitLinkageDirective(F, O);
605  if (isKernelFunction(*F))
606    O << ".entry ";
607  else
608    O << ".func ";
609  printReturnValStr(F, O);
610  getSymbol(F)->print(O, MAI);
611  O << "\n";
612  emitFunctionParamList(F, O);
613  O << ";\n";
614}
615
616static bool usedInGlobalVarDef(const Constant *C) {
617  if (!C)
618    return false;
619
620  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
621    return GV->getName() != "llvm.used";
622  }
623
624  for (const User *U : C->users())
625    if (const Constant *C = dyn_cast<Constant>(U))
626      if (usedInGlobalVarDef(C))
627        return true;
628
629  return false;
630}
631
632static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
633  if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
634    if (othergv->getName() == "llvm.used")
635      return true;
636  }
637
638  if (const Instruction *instr = dyn_cast<Instruction>(U)) {
639    if (instr->getParent() && instr->getParent()->getParent()) {
640      const Function *curFunc = instr->getParent()->getParent();
641      if (oneFunc && (curFunc != oneFunc))
642        return false;
643      oneFunc = curFunc;
644      return true;
645    } else
646      return false;
647  }
648
649  for (const User *UU : U->users())
650    if (!usedInOneFunc(UU, oneFunc))
651      return false;
652
653  return true;
654}
655
656/* Find out if a global variable can be demoted to local scope.
657 * Currently, this is valid for CUDA shared variables, which have local
658 * scope and global lifetime. So the conditions to check are :
659 * 1. Is the global variable in shared address space?
660 * 2. Does it have internal linkage?
661 * 3. Is the global variable referenced only in one function?
662 */
663static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
664  if (!gv->hasInternalLinkage())
665    return false;
666  PointerType *Pty = gv->getType();
667  if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
668    return false;
669
670  const Function *oneFunc = nullptr;
671
672  bool flag = usedInOneFunc(gv, oneFunc);
673  if (!flag)
674    return false;
675  if (!oneFunc)
676    return false;
677  f = oneFunc;
678  return true;
679}
680
681static bool useFuncSeen(const Constant *C,
682                        DenseMap<const Function *, bool> &seenMap) {
683  for (const User *U : C->users()) {
684    if (const Constant *cu = dyn_cast<Constant>(U)) {
685      if (useFuncSeen(cu, seenMap))
686        return true;
687    } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
688      const BasicBlock *bb = I->getParent();
689      if (!bb)
690        continue;
691      const Function *caller = bb->getParent();
692      if (!caller)
693        continue;
694      if (seenMap.find(caller) != seenMap.end())
695        return true;
696    }
697  }
698  return false;
699}
700
701void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
702  DenseMap<const Function *, bool> seenMap;
703  for (Module::const_iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) {
704    const Function *F = &*FI;
705
706    if (F->getAttributes().hasFnAttribute("nvptx-libcall-callee")) {
707      emitDeclaration(F, O);
708      continue;
709    }
710
711    if (F->isDeclaration()) {
712      if (F->use_empty())
713        continue;
714      if (F->getIntrinsicID())
715        continue;
716      emitDeclaration(F, O);
717      continue;
718    }
719    for (const User *U : F->users()) {
720      if (const Constant *C = dyn_cast<Constant>(U)) {
721        if (usedInGlobalVarDef(C)) {
722          // The use is in the initialization of a global variable
723          // that is a function pointer, so print a declaration
724          // for the original function
725          emitDeclaration(F, O);
726          break;
727        }
728        // Emit a declaration of this function if the function that
729        // uses this constant expr has already been seen.
730        if (useFuncSeen(C, seenMap)) {
731          emitDeclaration(F, O);
732          break;
733        }
734      }
735
736      if (!isa<Instruction>(U))
737        continue;
738      const Instruction *instr = cast<Instruction>(U);
739      const BasicBlock *bb = instr->getParent();
740      if (!bb)
741        continue;
742      const Function *caller = bb->getParent();
743      if (!caller)
744        continue;
745
746      // If a caller has already been seen, then the caller is
747      // appearing in the module before the callee. so print out
748      // a declaration for the callee.
749      if (seenMap.find(caller) != seenMap.end()) {
750        emitDeclaration(F, O);
751        break;
752      }
753    }
754    seenMap[F] = true;
755  }
756}
757
758static bool isEmptyXXStructor(GlobalVariable *GV) {
759  if (!GV) return true;
760  const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());
761  if (!InitList) return true;  // Not an array; we don't know how to parse.
762  return InitList->getNumOperands() == 0;
763}
764
765bool NVPTXAsmPrinter::doInitialization(Module &M) {
766  // Construct a default subtarget off of the TargetMachine defaults. The
767  // rest of NVPTX isn't friendly to change subtargets per function and
768  // so the default TargetMachine will have all of the options.
769  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
770  const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());
771
772  if (M.alias_size()) {
773    report_fatal_error("Module has aliases, which NVPTX does not support.");
774    return true; // error
775  }
776  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors"))) {
777    report_fatal_error(
778        "Module has a nontrivial global ctor, which NVPTX does not support.");
779    return true;  // error
780  }
781  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors"))) {
782    report_fatal_error(
783        "Module has a nontrivial global dtor, which NVPTX does not support.");
784    return true;  // error
785  }
786
787  SmallString<128> Str1;
788  raw_svector_ostream OS1(Str1);
789
790  // We need to call the parent's one explicitly.
791  bool Result = AsmPrinter::doInitialization(M);
792
793  // Emit header before any dwarf directives are emitted below.
794  emitHeader(M, OS1, *STI);
795  OutStreamer->EmitRawText(OS1.str());
796
797  // Emit module-level inline asm if it exists.
798  if (!M.getModuleInlineAsm().empty()) {
799    OutStreamer->AddComment("Start of file scope inline assembly");
800    OutStreamer->AddBlankLine();
801    OutStreamer->EmitRawText(StringRef(M.getModuleInlineAsm()));
802    OutStreamer->AddBlankLine();
803    OutStreamer->AddComment("End of file scope inline assembly");
804    OutStreamer->AddBlankLine();
805  }
806
807  GlobalsEmitted = false;
808
809  return Result;
810}
811
812void NVPTXAsmPrinter::emitGlobals(const Module &M) {
813  SmallString<128> Str2;
814  raw_svector_ostream OS2(Str2);
815
816  emitDeclarations(M, OS2);
817
818  // As ptxas does not support forward references of globals, we need to first
819  // sort the list of module-level globals in def-use order. We visit each
820  // global variable in order, and ensure that we emit it *after* its dependent
821  // globals. We use a little extra memory maintaining both a set and a list to
822  // have fast searches while maintaining a strict ordering.
823  SmallVector<const GlobalVariable *, 8> Globals;
824  DenseSet<const GlobalVariable *> GVVisited;
825  DenseSet<const GlobalVariable *> GVVisiting;
826
827  // Visit each global variable, in order
828  for (const GlobalVariable &I : M.globals())
829    VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
830
831  assert(GVVisited.size() == M.getGlobalList().size() &&
832         "Missed a global variable");
833  assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
834
835  // Print out module-level global variables in proper order
836  for (unsigned i = 0, e = Globals.size(); i != e; ++i)
837    printModuleLevelGV(Globals[i], OS2);
838
839  OS2 << '\n';
840
841  OutStreamer->EmitRawText(OS2.str());
842}
843
844void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
845                                 const NVPTXSubtarget &STI) {
846  O << "//\n";
847  O << "// Generated by LLVM NVPTX Back-End\n";
848  O << "//\n";
849  O << "\n";
850
851  unsigned PTXVersion = STI.getPTXVersion();
852  O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
853
854  O << ".target ";
855  O << STI.getTargetName();
856
857  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
858  if (NTM.getDrvInterface() == NVPTX::NVCL)
859    O << ", texmode_independent";
860
861  bool HasFullDebugInfo = false;
862  for (DICompileUnit *CU : M.debug_compile_units()) {
863    switch(CU->getEmissionKind()) {
864    case DICompileUnit::NoDebug:
865    case DICompileUnit::DebugDirectivesOnly:
866      break;
867    case DICompileUnit::LineTablesOnly:
868    case DICompileUnit::FullDebug:
869      HasFullDebugInfo = true;
870      break;
871    }
872    if (HasFullDebugInfo)
873      break;
874  }
875  if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
876    O << ", debug";
877
878  O << "\n";
879
880  O << ".address_size ";
881  if (NTM.is64Bit())
882    O << "64";
883  else
884    O << "32";
885  O << "\n";
886
887  O << "\n";
888}
889
890bool NVPTXAsmPrinter::doFinalization(Module &M) {
891  bool HasDebugInfo = MMI && MMI->hasDebugInfo();
892
893  // If we did not emit any functions, then the global declarations have not
894  // yet been emitted.
895  if (!GlobalsEmitted) {
896    emitGlobals(M);
897    GlobalsEmitted = true;
898  }
899
900  // XXX Temproarily remove global variables so that doFinalization() will not
901  // emit them again (global variables are emitted at beginning).
902
903  Module::GlobalListType &global_list = M.getGlobalList();
904  int i, n = global_list.size();
905  GlobalVariable **gv_array = new GlobalVariable *[n];
906
907  // first, back-up GlobalVariable in gv_array
908  i = 0;
909  for (Module::global_iterator I = global_list.begin(), E = global_list.end();
910       I != E; ++I)
911    gv_array[i++] = &*I;
912
913  // second, empty global_list
914  while (!global_list.empty())
915    global_list.remove(global_list.begin());
916
917  // call doFinalization
918  bool ret = AsmPrinter::doFinalization(M);
919
920  // now we restore global variables
921  for (i = 0; i < n; i++)
922    global_list.insert(global_list.end(), gv_array[i]);
923
924  clearAnnotationCache(&M);
925
926  delete[] gv_array;
927  // Close the last emitted section
928  if (HasDebugInfo) {
929    static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer())
930        ->closeLastSection();
931    // Emit empty .debug_loc section for better support of the empty files.
932    OutStreamer->EmitRawText("\t.section\t.debug_loc\t{\t}");
933  }
934
935  // Output last DWARF .file directives, if any.
936  static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer())
937      ->outputDwarfFileDirectives();
938
939  return ret;
940
941  //bool Result = AsmPrinter::doFinalization(M);
942  // Instead of calling the parents doFinalization, we may
943  // clone parents doFinalization and customize here.
944  // Currently, we if NVISA out the EmitGlobals() in
945  // parent's doFinalization, which is too intrusive.
946  //
947  // Same for the doInitialization.
948  //return Result;
949}
950
951// This function emits appropriate linkage directives for
952// functions and global variables.
953//
954// extern function declaration            -> .extern
955// extern function definition             -> .visible
956// external global variable with init     -> .visible
957// external without init                  -> .extern
958// appending                              -> not allowed, assert.
959// for any linkage other than
960// internal, private, linker_private,
961// linker_private_weak, linker_private_weak_def_auto,
962// we emit                                -> .weak.
963
964void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
965                                           raw_ostream &O) {
966  if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
967    if (V->hasExternalLinkage()) {
968      if (isa<GlobalVariable>(V)) {
969        const GlobalVariable *GVar = cast<GlobalVariable>(V);
970        if (GVar) {
971          if (GVar->hasInitializer())
972            O << ".visible ";
973          else
974            O << ".extern ";
975        }
976      } else if (V->isDeclaration())
977        O << ".extern ";
978      else
979        O << ".visible ";
980    } else if (V->hasAppendingLinkage()) {
981      std::string msg;
982      msg.append("Error: ");
983      msg.append("Symbol ");
984      if (V->hasName())
985        msg.append(V->getName());
986      msg.append("has unsupported appending linkage type");
987      llvm_unreachable(msg.c_str());
988    } else if (!V->hasInternalLinkage() &&
989               !V->hasPrivateLinkage()) {
990      O << ".weak ";
991    }
992  }
993}
994
995void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
996                                         raw_ostream &O,
997                                         bool processDemoted) {
998  // Skip meta data
999  if (GVar->hasSection()) {
1000    if (GVar->getSection() == "llvm.metadata")
1001      return;
1002  }
1003
1004  // Skip LLVM intrinsic global variables
1005  if (GVar->getName().startswith("llvm.") ||
1006      GVar->getName().startswith("nvvm."))
1007    return;
1008
1009  const DataLayout &DL = getDataLayout();
1010
1011  // GlobalVariables are always constant pointers themselves.
1012  PointerType *PTy = GVar->getType();
1013  Type *ETy = GVar->getValueType();
1014
1015  if (GVar->hasExternalLinkage()) {
1016    if (GVar->hasInitializer())
1017      O << ".visible ";
1018    else
1019      O << ".extern ";
1020  } else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
1021             GVar->hasAvailableExternallyLinkage() ||
1022             GVar->hasCommonLinkage()) {
1023    O << ".weak ";
1024  }
1025
1026  if (isTexture(*GVar)) {
1027    O << ".global .texref " << getTextureName(*GVar) << ";\n";
1028    return;
1029  }
1030
1031  if (isSurface(*GVar)) {
1032    O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1033    return;
1034  }
1035
1036  if (GVar->isDeclaration()) {
1037    // (extern) declarations, no definition or initializer
1038    // Currently the only known declaration is for an automatic __local
1039    // (.shared) promoted to global.
1040    emitPTXGlobalVariable(GVar, O);
1041    O << ";\n";
1042    return;
1043  }
1044
1045  if (isSampler(*GVar)) {
1046    O << ".global .samplerref " << getSamplerName(*GVar);
1047
1048    const Constant *Initializer = nullptr;
1049    if (GVar->hasInitializer())
1050      Initializer = GVar->getInitializer();
1051    const ConstantInt *CI = nullptr;
1052    if (Initializer)
1053      CI = dyn_cast<ConstantInt>(Initializer);
1054    if (CI) {
1055      unsigned sample = CI->getZExtValue();
1056
1057      O << " = { ";
1058
1059      for (int i = 0,
1060               addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1061           i < 3; i++) {
1062        O << "addr_mode_" << i << " = ";
1063        switch (addr) {
1064        case 0:
1065          O << "wrap";
1066          break;
1067        case 1:
1068          O << "clamp_to_border";
1069          break;
1070        case 2:
1071          O << "clamp_to_edge";
1072          break;
1073        case 3:
1074          O << "wrap";
1075          break;
1076        case 4:
1077          O << "mirror";
1078          break;
1079        }
1080        O << ", ";
1081      }
1082      O << "filter_mode = ";
1083      switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1084      case 0:
1085        O << "nearest";
1086        break;
1087      case 1:
1088        O << "linear";
1089        break;
1090      case 2:
1091        llvm_unreachable("Anisotropic filtering is not supported");
1092      default:
1093        O << "nearest";
1094        break;
1095      }
1096      if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1097        O << ", force_unnormalized_coords = 1";
1098      }
1099      O << " }";
1100    }
1101
1102    O << ";\n";
1103    return;
1104  }
1105
1106  if (GVar->hasPrivateLinkage()) {
1107    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1108      return;
1109
1110    // FIXME - need better way (e.g. Metadata) to avoid generating this global
1111    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1112      return;
1113    if (GVar->use_empty())
1114      return;
1115  }
1116
1117  const Function *demotedFunc = nullptr;
1118  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1119    O << "// " << GVar->getName() << " has been demoted\n";
1120    if (localDecls.find(demotedFunc) != localDecls.end())
1121      localDecls[demotedFunc].push_back(GVar);
1122    else {
1123      std::vector<const GlobalVariable *> temp;
1124      temp.push_back(GVar);
1125      localDecls[demotedFunc] = temp;
1126    }
1127    return;
1128  }
1129
1130  O << ".";
1131  emitPTXAddressSpace(PTy->getAddressSpace(), O);
1132
1133  if (isManaged(*GVar)) {
1134    O << " .attribute(.managed)";
1135  }
1136
1137  if (GVar->getAlignment() == 0)
1138    O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
1139  else
1140    O << " .align " << GVar->getAlignment();
1141
1142  if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1143      (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1144    O << " .";
1145    // Special case: ABI requires that we use .u8 for predicates
1146    if (ETy->isIntegerTy(1))
1147      O << "u8";
1148    else
1149      O << getPTXFundamentalTypeStr(ETy, false);
1150    O << " ";
1151    getSymbol(GVar)->print(O, MAI);
1152
1153    // Ptx allows variable initilization only for constant and global state
1154    // spaces.
1155    if (GVar->hasInitializer()) {
1156      if ((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1157          (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1158        const Constant *Initializer = GVar->getInitializer();
1159        // 'undef' is treated as there is no value specified.
1160        if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1161          O << " = ";
1162          printScalarConstant(Initializer, O);
1163        }
1164      } else {
1165        // The frontend adds zero-initializer to device and constant variables
1166        // that don't have an initial value, and UndefValue to shared
1167        // variables, so skip warning for this case.
1168        if (!GVar->getInitializer()->isNullValue() &&
1169            !isa<UndefValue>(GVar->getInitializer())) {
1170          report_fatal_error("initial value of '" + GVar->getName() +
1171                             "' is not allowed in addrspace(" +
1172                             Twine(PTy->getAddressSpace()) + ")");
1173        }
1174      }
1175    }
1176  } else {
1177    unsigned int ElementSize = 0;
1178
1179    // Although PTX has direct support for struct type and array type and
1180    // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1181    // targets that support these high level field accesses. Structs, arrays
1182    // and vectors are lowered into arrays of bytes.
1183    switch (ETy->getTypeID()) {
1184    case Type::IntegerTyID: // Integers larger than 64 bits
1185    case Type::StructTyID:
1186    case Type::ArrayTyID:
1187    case Type::VectorTyID:
1188      ElementSize = DL.getTypeStoreSize(ETy);
1189      // Ptx allows variable initilization only for constant and
1190      // global state spaces.
1191      if (((PTy->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1192           (PTy->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1193          GVar->hasInitializer()) {
1194        const Constant *Initializer = GVar->getInitializer();
1195        if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1196          AggBuffer aggBuffer(ElementSize, O, *this);
1197          bufferAggregateConstant(Initializer, &aggBuffer);
1198          if (aggBuffer.numSymbols) {
1199            if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit()) {
1200              O << " .u64 ";
1201              getSymbol(GVar)->print(O, MAI);
1202              O << "[";
1203              O << ElementSize / 8;
1204            } else {
1205              O << " .u32 ";
1206              getSymbol(GVar)->print(O, MAI);
1207              O << "[";
1208              O << ElementSize / 4;
1209            }
1210            O << "]";
1211          } else {
1212            O << " .b8 ";
1213            getSymbol(GVar)->print(O, MAI);
1214            O << "[";
1215            O << ElementSize;
1216            O << "]";
1217          }
1218          O << " = {";
1219          aggBuffer.print();
1220          O << "}";
1221        } else {
1222          O << " .b8 ";
1223          getSymbol(GVar)->print(O, MAI);
1224          if (ElementSize) {
1225            O << "[";
1226            O << ElementSize;
1227            O << "]";
1228          }
1229        }
1230      } else {
1231        O << " .b8 ";
1232        getSymbol(GVar)->print(O, MAI);
1233        if (ElementSize) {
1234          O << "[";
1235          O << ElementSize;
1236          O << "]";
1237        }
1238      }
1239      break;
1240    default:
1241      llvm_unreachable("type not supported yet");
1242    }
1243  }
1244  O << ";\n";
1245}
1246
1247void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1248  if (localDecls.find(f) == localDecls.end())
1249    return;
1250
1251  std::vector<const GlobalVariable *> &gvars = localDecls[f];
1252
1253  for (unsigned i = 0, e = gvars.size(); i != e; ++i) {
1254    O << "\t// demoted variable\n\t";
1255    printModuleLevelGV(gvars[i], O, true);
1256  }
1257}
1258
1259void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1260                                          raw_ostream &O) const {
1261  switch (AddressSpace) {
1262  case ADDRESS_SPACE_LOCAL:
1263    O << "local";
1264    break;
1265  case ADDRESS_SPACE_GLOBAL:
1266    O << "global";
1267    break;
1268  case ADDRESS_SPACE_CONST:
1269    O << "const";
1270    break;
1271  case ADDRESS_SPACE_SHARED:
1272    O << "shared";
1273    break;
1274  default:
1275    report_fatal_error("Bad address space found while emitting PTX: " +
1276                       llvm::Twine(AddressSpace));
1277    break;
1278  }
1279}
1280
1281std::string
1282NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1283  switch (Ty->getTypeID()) {
1284  default:
1285    llvm_unreachable("unexpected type");
1286    break;
1287  case Type::IntegerTyID: {
1288    unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1289    if (NumBits == 1)
1290      return "pred";
1291    else if (NumBits <= 64) {
1292      std::string name = "u";
1293      return name + utostr(NumBits);
1294    } else {
1295      llvm_unreachable("Integer too large");
1296      break;
1297    }
1298    break;
1299  }
1300  case Type::HalfTyID:
1301    // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
1302    return "b16";
1303  case Type::FloatTyID:
1304    return "f32";
1305  case Type::DoubleTyID:
1306    return "f64";
1307  case Type::PointerTyID:
1308    if (static_cast<const NVPTXTargetMachine &>(TM).is64Bit())
1309      if (useB4PTR)
1310        return "b64";
1311      else
1312        return "u64";
1313    else if (useB4PTR)
1314      return "b32";
1315    else
1316      return "u32";
1317  }
1318  llvm_unreachable("unexpected type");
1319  return nullptr;
1320}
1321
1322void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1323                                            raw_ostream &O) {
1324  const DataLayout &DL = getDataLayout();
1325
1326  // GlobalVariables are always constant pointers themselves.
1327  Type *ETy = GVar->getValueType();
1328
1329  O << ".";
1330  emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1331  if (GVar->getAlignment() == 0)
1332    O << " .align " << (int)DL.getPrefTypeAlignment(ETy);
1333  else
1334    O << " .align " << GVar->getAlignment();
1335
1336  // Special case for i128
1337  if (ETy->isIntegerTy(128)) {
1338    O << " .b8 ";
1339    getSymbol(GVar)->print(O, MAI);
1340    O << "[16]";
1341    return;
1342  }
1343
1344  if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1345    O << " .";
1346    O << getPTXFundamentalTypeStr(ETy);
1347    O << " ";
1348    getSymbol(GVar)->print(O, MAI);
1349    return;
1350  }
1351
1352  int64_t ElementSize = 0;
1353
1354  // Although PTX has direct support for struct type and array type and LLVM IR
1355  // is very similar to PTX, the LLVM CodeGen does not support for targets that
1356  // support these high level field accesses. Structs and arrays are lowered
1357  // into arrays of bytes.
1358  switch (ETy->getTypeID()) {
1359  case Type::StructTyID:
1360  case Type::ArrayTyID:
1361  case Type::VectorTyID:
1362    ElementSize = DL.getTypeStoreSize(ETy);
1363    O << " .b8 ";
1364    getSymbol(GVar)->print(O, MAI);
1365    O << "[";
1366    if (ElementSize) {
1367      O << ElementSize;
1368    }
1369    O << "]";
1370    break;
1371  default:
1372    llvm_unreachable("type not supported yet");
1373  }
1374}
1375
1376static unsigned int getOpenCLAlignment(const DataLayout &DL, Type *Ty) {
1377  if (Ty->isSingleValueType())
1378    return DL.getPrefTypeAlignment(Ty);
1379
1380  auto *ATy = dyn_cast<ArrayType>(Ty);
1381  if (ATy)
1382    return getOpenCLAlignment(DL, ATy->getElementType());
1383
1384  auto *STy = dyn_cast<StructType>(Ty);
1385  if (STy) {
1386    unsigned int alignStruct = 1;
1387    // Go through each element of the struct and find the
1388    // largest alignment.
1389    for (unsigned i = 0, e = STy->getNumElements(); i != e; i++) {
1390      Type *ETy = STy->getElementType(i);
1391      unsigned int align = getOpenCLAlignment(DL, ETy);
1392      if (align > alignStruct)
1393        alignStruct = align;
1394    }
1395    return alignStruct;
1396  }
1397
1398  auto *FTy = dyn_cast<FunctionType>(Ty);
1399  if (FTy)
1400    return DL.getPointerPrefAlignment().value();
1401  return DL.getPrefTypeAlignment(Ty);
1402}
1403
1404void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
1405                                     int paramIndex, raw_ostream &O) {
1406  getSymbol(I->getParent())->print(O, MAI);
1407  O << "_param_" << paramIndex;
1408}
1409
1410void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1411  const DataLayout &DL = getDataLayout();
1412  const AttributeList &PAL = F->getAttributes();
1413  const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1414  const TargetLowering *TLI = STI.getTargetLowering();
1415  Function::const_arg_iterator I, E;
1416  unsigned paramIndex = 0;
1417  bool first = true;
1418  bool isKernelFunc = isKernelFunction(*F);
1419  bool isABI = (STI.getSmVersion() >= 20);
1420  bool hasImageHandles = STI.hasImageHandles();
1421  MVT thePointerTy = TLI->getPointerTy(DL);
1422
1423  if (F->arg_empty()) {
1424    O << "()\n";
1425    return;
1426  }
1427
1428  O << "(\n";
1429
1430  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1431    Type *Ty = I->getType();
1432
1433    if (!first)
1434      O << ",\n";
1435
1436    first = false;
1437
1438    // Handle image/sampler parameters
1439    if (isKernelFunction(*F)) {
1440      if (isSampler(*I) || isImage(*I)) {
1441        if (isImage(*I)) {
1442          std::string sname = I->getName();
1443          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1444            if (hasImageHandles)
1445              O << "\t.param .u64 .ptr .surfref ";
1446            else
1447              O << "\t.param .surfref ";
1448            CurrentFnSym->print(O, MAI);
1449            O << "_param_" << paramIndex;
1450          }
1451          else { // Default image is read_only
1452            if (hasImageHandles)
1453              O << "\t.param .u64 .ptr .texref ";
1454            else
1455              O << "\t.param .texref ";
1456            CurrentFnSym->print(O, MAI);
1457            O << "_param_" << paramIndex;
1458          }
1459        } else {
1460          if (hasImageHandles)
1461            O << "\t.param .u64 .ptr .samplerref ";
1462          else
1463            O << "\t.param .samplerref ";
1464          CurrentFnSym->print(O, MAI);
1465          O << "_param_" << paramIndex;
1466        }
1467        continue;
1468      }
1469    }
1470
1471    if (!PAL.hasParamAttribute(paramIndex, Attribute::ByVal)) {
1472      if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
1473        // Just print .param .align <a> .b8 .param[size];
1474        // <a> = PAL.getparamalignment
1475        // size = typeallocsize of element type
1476        const Align align = DL.getValueOrABITypeAlignment(
1477            PAL.getParamAlignment(paramIndex), Ty);
1478
1479        unsigned sz = DL.getTypeAllocSize(Ty);
1480        O << "\t.param .align " << align.value() << " .b8 ";
1481        printParamName(I, paramIndex, O);
1482        O << "[" << sz << "]";
1483
1484        continue;
1485      }
1486      // Just a scalar
1487      auto *PTy = dyn_cast<PointerType>(Ty);
1488      if (isKernelFunc) {
1489        if (PTy) {
1490          // Special handling for pointer arguments to kernel
1491          O << "\t.param .u" << thePointerTy.getSizeInBits() << " ";
1492
1493          if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1494              NVPTX::CUDA) {
1495            Type *ETy = PTy->getElementType();
1496            int addrSpace = PTy->getAddressSpace();
1497            switch (addrSpace) {
1498            default:
1499              O << ".ptr ";
1500              break;
1501            case ADDRESS_SPACE_CONST:
1502              O << ".ptr .const ";
1503              break;
1504            case ADDRESS_SPACE_SHARED:
1505              O << ".ptr .shared ";
1506              break;
1507            case ADDRESS_SPACE_GLOBAL:
1508              O << ".ptr .global ";
1509              break;
1510            }
1511            O << ".align " << (int)getOpenCLAlignment(DL, ETy) << " ";
1512          }
1513          printParamName(I, paramIndex, O);
1514          continue;
1515        }
1516
1517        // non-pointer scalar to kernel func
1518        O << "\t.param .";
1519        // Special case: predicate operands become .u8 types
1520        if (Ty->isIntegerTy(1))
1521          O << "u8";
1522        else
1523          O << getPTXFundamentalTypeStr(Ty);
1524        O << " ";
1525        printParamName(I, paramIndex, O);
1526        continue;
1527      }
1528      // Non-kernel function, just print .param .b<size> for ABI
1529      // and .reg .b<size> for non-ABI
1530      unsigned sz = 0;
1531      if (isa<IntegerType>(Ty)) {
1532        sz = cast<IntegerType>(Ty)->getBitWidth();
1533        if (sz < 32)
1534          sz = 32;
1535      } else if (isa<PointerType>(Ty))
1536        sz = thePointerTy.getSizeInBits();
1537      else if (Ty->isHalfTy())
1538        // PTX ABI requires all scalar parameters to be at least 32
1539        // bits in size.  fp16 normally uses .b16 as its storage type
1540        // in PTX, so its size must be adjusted here, too.
1541        sz = 32;
1542      else
1543        sz = Ty->getPrimitiveSizeInBits();
1544      if (isABI)
1545        O << "\t.param .b" << sz << " ";
1546      else
1547        O << "\t.reg .b" << sz << " ";
1548      printParamName(I, paramIndex, O);
1549      continue;
1550    }
1551
1552    // param has byVal attribute. So should be a pointer
1553    auto *PTy = dyn_cast<PointerType>(Ty);
1554    assert(PTy && "Param with byval attribute should be a pointer type");
1555    Type *ETy = PTy->getElementType();
1556
1557    if (isABI || isKernelFunc) {
1558      // Just print .param .align <a> .b8 .param[size];
1559      // <a> = PAL.getparamalignment
1560      // size = typeallocsize of element type
1561      Align align =
1562          DL.getValueOrABITypeAlignment(PAL.getParamAlignment(paramIndex), ETy);
1563      // Work around a bug in ptxas. When PTX code takes address of
1564      // byval parameter with alignment < 4, ptxas generates code to
1565      // spill argument into memory. Alas on sm_50+ ptxas generates
1566      // SASS code that fails with misaligned access. To work around
1567      // the problem, make sure that we align byval parameters by at
1568      // least 4. Matching change must be made in LowerCall() where we
1569      // prepare parameters for the call.
1570      //
1571      // TODO: this will need to be undone when we get to support multi-TU
1572      // device-side compilation as it breaks ABI compatibility with nvcc.
1573      // Hopefully ptxas bug is fixed by then.
1574      if (!isKernelFunc && align < Align(4))
1575        align = Align(4);
1576      unsigned sz = DL.getTypeAllocSize(ETy);
1577      O << "\t.param .align " << align.value() << " .b8 ";
1578      printParamName(I, paramIndex, O);
1579      O << "[" << sz << "]";
1580      continue;
1581    } else {
1582      // Split the ETy into constituent parts and
1583      // print .param .b<size> <name> for each part.
1584      // Further, if a part is vector, print the above for
1585      // each vector element.
1586      SmallVector<EVT, 16> vtparts;
1587      ComputeValueVTs(*TLI, DL, ETy, vtparts);
1588      for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1589        unsigned elems = 1;
1590        EVT elemtype = vtparts[i];
1591        if (vtparts[i].isVector()) {
1592          elems = vtparts[i].getVectorNumElements();
1593          elemtype = vtparts[i].getVectorElementType();
1594        }
1595
1596        for (unsigned j = 0, je = elems; j != je; ++j) {
1597          unsigned sz = elemtype.getSizeInBits();
1598          if (elemtype.isInteger() && (sz < 32))
1599            sz = 32;
1600          O << "\t.reg .b" << sz << " ";
1601          printParamName(I, paramIndex, O);
1602          if (j < je - 1)
1603            O << ",\n";
1604          ++paramIndex;
1605        }
1606        if (i < e - 1)
1607          O << ",\n";
1608      }
1609      --paramIndex;
1610      continue;
1611    }
1612  }
1613
1614  O << "\n)\n";
1615}
1616
1617void NVPTXAsmPrinter::emitFunctionParamList(const MachineFunction &MF,
1618                                            raw_ostream &O) {
1619  const Function &F = MF.getFunction();
1620  emitFunctionParamList(&F, O);
1621}
1622
1623void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1624    const MachineFunction &MF) {
1625  SmallString<128> Str;
1626  raw_svector_ostream O(Str);
1627
1628  // Map the global virtual register number to a register class specific
1629  // virtual register number starting from 1 with that class.
1630  const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1631  //unsigned numRegClasses = TRI->getNumRegClasses();
1632
1633  // Emit the Fake Stack Object
1634  const MachineFrameInfo &MFI = MF.getFrameInfo();
1635  int NumBytes = (int) MFI.getStackSize();
1636  if (NumBytes) {
1637    O << "\t.local .align " << MFI.getMaxAlignment() << " .b8 \t" << DEPOTNAME
1638      << getFunctionNumber() << "[" << NumBytes << "];\n";
1639    if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1640      O << "\t.reg .b64 \t%SP;\n";
1641      O << "\t.reg .b64 \t%SPL;\n";
1642    } else {
1643      O << "\t.reg .b32 \t%SP;\n";
1644      O << "\t.reg .b32 \t%SPL;\n";
1645    }
1646  }
1647
1648  // Go through all virtual registers to establish the mapping between the
1649  // global virtual
1650  // register number and the per class virtual register number.
1651  // We use the per class virtual register number in the ptx output.
1652  unsigned int numVRs = MRI->getNumVirtRegs();
1653  for (unsigned i = 0; i < numVRs; i++) {
1654    unsigned int vr = Register::index2VirtReg(i);
1655    const TargetRegisterClass *RC = MRI->getRegClass(vr);
1656    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1657    int n = regmap.size();
1658    regmap.insert(std::make_pair(vr, n + 1));
1659  }
1660
1661  // Emit register declarations
1662  // @TODO: Extract out the real register usage
1663  // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1664  // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1665  // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1666  // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1667  // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1668  // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1669  // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1670
1671  // Emit declaration of the virtual registers or 'physical' registers for
1672  // each register class
1673  for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1674    const TargetRegisterClass *RC = TRI->getRegClass(i);
1675    DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1676    std::string rcname = getNVPTXRegClassName(RC);
1677    std::string rcStr = getNVPTXRegClassStr(RC);
1678    int n = regmap.size();
1679
1680    // Only declare those registers that may be used.
1681    if (n) {
1682       O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1683         << ">;\n";
1684    }
1685  }
1686
1687  OutStreamer->EmitRawText(O.str());
1688}
1689
1690void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1691  APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1692  bool ignored;
1693  unsigned int numHex;
1694  const char *lead;
1695
1696  if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1697    numHex = 8;
1698    lead = "0f";
1699    APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1700  } else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1701    numHex = 16;
1702    lead = "0d";
1703    APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1704  } else
1705    llvm_unreachable("unsupported fp type");
1706
1707  APInt API = APF.bitcastToAPInt();
1708  O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1709}
1710
1711void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1712  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1713    O << CI->getValue();
1714    return;
1715  }
1716  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1717    printFPConstant(CFP, O);
1718    return;
1719  }
1720  if (isa<ConstantPointerNull>(CPV)) {
1721    O << "0";
1722    return;
1723  }
1724  if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1725    bool IsNonGenericPointer = false;
1726    if (GVar->getType()->getAddressSpace() != 0) {
1727      IsNonGenericPointer = true;
1728    }
1729    if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1730      O << "generic(";
1731      getSymbol(GVar)->print(O, MAI);
1732      O << ")";
1733    } else {
1734      getSymbol(GVar)->print(O, MAI);
1735    }
1736    return;
1737  }
1738  if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1739    const Value *v = Cexpr->stripPointerCasts();
1740    PointerType *PTy = dyn_cast<PointerType>(Cexpr->getType());
1741    bool IsNonGenericPointer = false;
1742    if (PTy && PTy->getAddressSpace() != 0) {
1743      IsNonGenericPointer = true;
1744    }
1745    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1746      if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) {
1747        O << "generic(";
1748        getSymbol(GVar)->print(O, MAI);
1749        O << ")";
1750      } else {
1751        getSymbol(GVar)->print(O, MAI);
1752      }
1753      return;
1754    } else {
1755      lowerConstant(CPV)->print(O, MAI);
1756      return;
1757    }
1758  }
1759  llvm_unreachable("Not scalar type found in printScalarConstant()");
1760}
1761
// These utility functions assure we get the right sequence of bytes for a given
// type even for big-endian machines

/// Store the sizeof(T) low-order bytes of \p val into \p p, least significant
/// byte first. \p p must have room for sizeof(T) bytes.
template <typename T> static void ConvertIntToBytes(unsigned char *p, T val) {
  // Widen exactly as before (through int64_t), then shift an unsigned copy:
  // right-shifting a negative signed value is implementation-defined before
  // C++20, whereas unsigned shifts are always well defined.
  const uint64_t bits = static_cast<uint64_t>(static_cast<int64_t>(val));
  for (unsigned i = 0; i < sizeof(T); ++i)
    p[i] = static_cast<unsigned char>(bits >> (8 * i));
}
/// Store the 4 bytes of the object representation of \p val into \p p, least
/// significant byte first, independent of host endianness.
static void ConvertFloatToBytes(unsigned char *p, float val) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "float is expected to be 32 bits wide");
  // Copy the bits out instead of reading the float through an int32_t
  // pointer, which violates the strict-aliasing rule.
  uint32_t bits;
  memcpy(&bits, &val, sizeof(bits));
  for (unsigned i = 0; i < sizeof(bits); ++i)
    p[i] = static_cast<unsigned char>(bits >> (8 * i));
}
/// Store the 8 bytes of the object representation of \p val into \p p, least
/// significant byte first, independent of host endianness.
static void ConvertDoubleToBytes(unsigned char *p, double val) {
  static_assert(sizeof(double) == sizeof(uint64_t),
                "double is expected to be 64 bits wide");
  // Copy the bits out instead of reading the double through an int64_t
  // pointer, which violates the strict-aliasing rule.
  uint64_t bits;
  memcpy(&bits, &val, sizeof(bits));
  for (unsigned i = 0; i < sizeof(bits); ++i)
    p[i] = static_cast<unsigned char>(bits >> (8 * i));
}
1785
1786void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1787                                   AggBuffer *aggBuffer) {
1788  const DataLayout &DL = getDataLayout();
1789
1790  if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1791    int s = DL.getTypeAllocSize(CPV->getType());
1792    if (s < Bytes)
1793      s = Bytes;
1794    aggBuffer->addZeros(s);
1795    return;
1796  }
1797
1798  unsigned char ptr[8];
1799  switch (CPV->getType()->getTypeID()) {
1800
1801  case Type::IntegerTyID: {
1802    Type *ETy = CPV->getType();
1803    if (ETy == Type::getInt8Ty(CPV->getContext())) {
1804      unsigned char c = (unsigned char)cast<ConstantInt>(CPV)->getZExtValue();
1805      ConvertIntToBytes<>(ptr, c);
1806      aggBuffer->addBytes(ptr, 1, Bytes);
1807    } else if (ETy == Type::getInt16Ty(CPV->getContext())) {
1808      short int16 = (short)cast<ConstantInt>(CPV)->getZExtValue();
1809      ConvertIntToBytes<>(ptr, int16);
1810      aggBuffer->addBytes(ptr, 2, Bytes);
1811    } else if (ETy == Type::getInt32Ty(CPV->getContext())) {
1812      if (const ConstantInt *constInt = dyn_cast<ConstantInt>(CPV)) {
1813        int int32 = (int)(constInt->getZExtValue());
1814        ConvertIntToBytes<>(ptr, int32);
1815        aggBuffer->addBytes(ptr, 4, Bytes);
1816        break;
1817      } else if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1818        if (const auto *constInt = dyn_cast_or_null<ConstantInt>(
1819                ConstantFoldConstant(Cexpr, DL))) {
1820          int int32 = (int)(constInt->getZExtValue());
1821          ConvertIntToBytes<>(ptr, int32);
1822          aggBuffer->addBytes(ptr, 4, Bytes);
1823          break;
1824        }
1825        if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1826          Value *v = Cexpr->getOperand(0)->stripPointerCasts();
1827          aggBuffer->addSymbol(v, Cexpr->getOperand(0));
1828          aggBuffer->addZeros(4);
1829          break;
1830        }
1831      }
1832      llvm_unreachable("unsupported integer const type");
1833    } else if (ETy == Type::getInt64Ty(CPV->getContext())) {
1834      if (const ConstantInt *constInt = dyn_cast<ConstantInt>(CPV)) {
1835        long long int64 = (long long)(constInt->getZExtValue());
1836        ConvertIntToBytes<>(ptr, int64);
1837        aggBuffer->addBytes(ptr, 8, Bytes);
1838        break;
1839      } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1840        if (const auto *constInt = dyn_cast_or_null<ConstantInt>(
1841                ConstantFoldConstant(Cexpr, DL))) {
1842          long long int64 = (long long)(constInt->getZExtValue());
1843          ConvertIntToBytes<>(ptr, int64);
1844          aggBuffer->addBytes(ptr, 8, Bytes);
1845          break;
1846        }
1847        if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1848          Value *v = Cexpr->getOperand(0)->stripPointerCasts();
1849          aggBuffer->addSymbol(v, Cexpr->getOperand(0));
1850          aggBuffer->addZeros(8);
1851          break;
1852        }
1853      }
1854      llvm_unreachable("unsupported integer const type");
1855    } else
1856      llvm_unreachable("unsupported integer const type");
1857    break;
1858  }
1859  case Type::HalfTyID:
1860  case Type::FloatTyID:
1861  case Type::DoubleTyID: {
1862    const auto *CFP = cast<ConstantFP>(CPV);
1863    Type *Ty = CFP->getType();
1864    if (Ty == Type::getHalfTy(CPV->getContext())) {
1865      APInt API = CFP->getValueAPF().bitcastToAPInt();
1866      uint16_t float16 = API.getLoBits(16).getZExtValue();
1867      ConvertIntToBytes<>(ptr, float16);
1868      aggBuffer->addBytes(ptr, 2, Bytes);
1869    } else if (Ty == Type::getFloatTy(CPV->getContext())) {
1870      float float32 = (float) CFP->getValueAPF().convertToFloat();
1871      ConvertFloatToBytes(ptr, float32);
1872      aggBuffer->addBytes(ptr, 4, Bytes);
1873    } else if (Ty == Type::getDoubleTy(CPV->getContext())) {
1874      double float64 = CFP->getValueAPF().convertToDouble();
1875      ConvertDoubleToBytes(ptr, float64);
1876      aggBuffer->addBytes(ptr, 8, Bytes);
1877    } else {
1878      llvm_unreachable("unsupported fp const type");
1879    }
1880    break;
1881  }
1882  case Type::PointerTyID: {
1883    if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1884      aggBuffer->addSymbol(GVar, GVar);
1885    } else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1886      const Value *v = Cexpr->stripPointerCasts();
1887      aggBuffer->addSymbol(v, Cexpr);
1888    }
1889    unsigned int s = DL.getTypeAllocSize(CPV->getType());
1890    aggBuffer->addZeros(s);
1891    break;
1892  }
1893
1894  case Type::ArrayTyID:
1895  case Type::VectorTyID:
1896  case Type::StructTyID: {
1897    if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1898      int ElementSize = DL.getTypeAllocSize(CPV->getType());
1899      bufferAggregateConstant(CPV, aggBuffer);
1900      if (Bytes > ElementSize)
1901        aggBuffer->addZeros(Bytes - ElementSize);
1902    } else if (isa<ConstantAggregateZero>(CPV))
1903      aggBuffer->addZeros(Bytes);
1904    else
1905      llvm_unreachable("Unexpected Constant type");
1906    break;
1907  }
1908
1909  default:
1910    llvm_unreachable("unsupported type");
1911  }
1912}
1913
1914void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1915                                              AggBuffer *aggBuffer) {
1916  const DataLayout &DL = getDataLayout();
1917  int Bytes;
1918
1919  // Integers of arbitrary width
1920  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1921    APInt Val = CI->getValue();
1922    for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1923      uint8_t Byte = Val.getLoBits(8).getZExtValue();
1924      aggBuffer->addBytes(&Byte, 1, 1);
1925      Val.lshrInPlace(8);
1926    }
1927    return;
1928  }
1929
1930  // Old constants
1931  if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1932    if (CPV->getNumOperands())
1933      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1934        bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1935    return;
1936  }
1937
1938  if (const ConstantDataSequential *CDS =
1939          dyn_cast<ConstantDataSequential>(CPV)) {
1940    if (CDS->getNumElements())
1941      for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1942        bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1943                     aggBuffer);
1944    return;
1945  }
1946
1947  if (isa<ConstantStruct>(CPV)) {
1948    if (CPV->getNumOperands()) {
1949      StructType *ST = cast<StructType>(CPV->getType());
1950      for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1951        if (i == (e - 1))
1952          Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1953                  DL.getTypeAllocSize(ST) -
1954                  DL.getStructLayout(ST)->getElementOffset(i);
1955        else
1956          Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1957                  DL.getStructLayout(ST)->getElementOffset(i);
1958        bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1959      }
1960    }
1961    return;
1962  }
1963  llvm_unreachable("unsupported constant type in printAggregateConstant()");
1964}
1965
1966/// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
1967/// a copy from AsmPrinter::lowerConstant, except customized to only handle
1968/// expressions that are representable in PTX and create
1969/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1970const MCExpr *
1971NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1972  MCContext &Ctx = OutContext;
1973
1974  if (CV->isNullValue() || isa<UndefValue>(CV))
1975    return MCConstantExpr::create(0, Ctx);
1976
1977  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1978    return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1979
1980  if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1981    const MCSymbolRefExpr *Expr =
1982      MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1983    if (ProcessingGeneric) {
1984      return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1985    } else {
1986      return Expr;
1987    }
1988  }
1989
1990  const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
1991  if (!CE) {
1992    llvm_unreachable("Unknown constant value to lower!");
1993  }
1994
1995  switch (CE->getOpcode()) {
1996  default:
1997    // If the code isn't optimized, there may be outstanding folding
1998    // opportunities. Attempt to fold the expression using DataLayout as a
1999    // last resort before giving up.
2000    if (Constant *C = ConstantFoldConstant(CE, getDataLayout()))
2001      if (C && C != CE)
2002        return lowerConstantForGV(C, ProcessingGeneric);
2003
2004    // Otherwise report the problem to the user.
2005    {
2006      std::string S;
2007      raw_string_ostream OS(S);
2008      OS << "Unsupported expression in static initializer: ";
2009      CE->printAsOperand(OS, /*PrintType=*/false,
2010                     !MF ? nullptr : MF->getFunction().getParent());
2011      report_fatal_error(OS.str());
2012    }
2013
2014  case Instruction::AddrSpaceCast: {
2015    // Strip the addrspacecast and pass along the operand
2016    PointerType *DstTy = cast<PointerType>(CE->getType());
2017    if (DstTy->getAddressSpace() == 0) {
2018      return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
2019    }
2020    std::string S;
2021    raw_string_ostream OS(S);
2022    OS << "Unsupported expression in static initializer: ";
2023    CE->printAsOperand(OS, /*PrintType=*/ false,
2024                       !MF ? nullptr : MF->getFunction().getParent());
2025    report_fatal_error(OS.str());
2026  }
2027
2028  case Instruction::GetElementPtr: {
2029    const DataLayout &DL = getDataLayout();
2030
2031    // Generate a symbolic expression for the byte address
2032    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2033    cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
2034
2035    const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2036                                            ProcessingGeneric);
2037    if (!OffsetAI)
2038      return Base;
2039
2040    int64_t Offset = OffsetAI.getSExtValue();
2041    return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2042                                   Ctx);
2043  }
2044
2045  case Instruction::Trunc:
2046    // We emit the value and depend on the assembler to truncate the generated
2047    // expression properly.  This is important for differences between
2048    // blockaddress labels.  Since the two labels are in the same function, it
2049    // is reasonable to treat their delta as a 32-bit value.
2050    LLVM_FALLTHROUGH;
2051  case Instruction::BitCast:
2052    return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2053
2054  case Instruction::IntToPtr: {
2055    const DataLayout &DL = getDataLayout();
2056
2057    // Handle casts to pointers by changing them into casts to the appropriate
2058    // integer type.  This promotes constant folding and simplifies this code.
2059    Constant *Op = CE->getOperand(0);
2060    Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2061                                      false/*ZExt*/);
2062    return lowerConstantForGV(Op, ProcessingGeneric);
2063  }
2064
2065  case Instruction::PtrToInt: {
2066    const DataLayout &DL = getDataLayout();
2067
2068    // Support only foldable casts to/from pointers that can be eliminated by
2069    // changing the pointer to the appropriately sized integer type.
2070    Constant *Op = CE->getOperand(0);
2071    Type *Ty = CE->getType();
2072
2073    const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2074
2075    // We can emit the pointer value into this slot if the slot is an
2076    // integer slot equal to the size of the pointer.
2077    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2078      return OpExpr;
2079
2080    // Otherwise the pointer is smaller than the resultant integer, mask off
2081    // the high bits so we are sure to get a proper truncation if the input is
2082    // a constant expr.
2083    unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2084    const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2085    return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2086  }
2087
2088  // The MC library also has a right-shift operator, but it isn't consistently
2089  // signed or unsigned between different targets.
2090  case Instruction::Add: {
2091    const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2092    const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2093    switch (CE->getOpcode()) {
2094    default: llvm_unreachable("Unknown binary operator constant cast expr");
2095    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2096    }
2097  }
2098  }
2099}
2100
2101// Copy of MCExpr::print customized for NVPTX
2102void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2103  switch (Expr.getKind()) {
2104  case MCExpr::Target:
2105    return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2106  case MCExpr::Constant:
2107    OS << cast<MCConstantExpr>(Expr).getValue();
2108    return;
2109
2110  case MCExpr::SymbolRef: {
2111    const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2112    const MCSymbol &Sym = SRE.getSymbol();
2113    Sym.print(OS, MAI);
2114    return;
2115  }
2116
2117  case MCExpr::Unary: {
2118    const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2119    switch (UE.getOpcode()) {
2120    case MCUnaryExpr::LNot:  OS << '!'; break;
2121    case MCUnaryExpr::Minus: OS << '-'; break;
2122    case MCUnaryExpr::Not:   OS << '~'; break;
2123    case MCUnaryExpr::Plus:  OS << '+'; break;
2124    }
2125    printMCExpr(*UE.getSubExpr(), OS);
2126    return;
2127  }
2128
2129  case MCExpr::Binary: {
2130    const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2131
2132    // Only print parens around the LHS if it is non-trivial.
2133    if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2134        isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2135      printMCExpr(*BE.getLHS(), OS);
2136    } else {
2137      OS << '(';
2138      printMCExpr(*BE.getLHS(), OS);
2139      OS<< ')';
2140    }
2141
2142    switch (BE.getOpcode()) {
2143    case MCBinaryExpr::Add:
2144      // Print "X-42" instead of "X+-42".
2145      if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2146        if (RHSC->getValue() < 0) {
2147          OS << RHSC->getValue();
2148          return;
2149        }
2150      }
2151
2152      OS <<  '+';
2153      break;
2154    default: llvm_unreachable("Unhandled binary operator");
2155    }
2156
2157    // Only print parens around the LHS if it is non-trivial.
2158    if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2159      printMCExpr(*BE.getRHS(), OS);
2160    } else {
2161      OS << '(';
2162      printMCExpr(*BE.getRHS(), OS);
2163      OS << ')';
2164    }
2165    return;
2166  }
2167  }
2168
2169  llvm_unreachable("Invalid expression kind!");
2170}
2171
2172/// PrintAsmOperand - Print out an operand for an inline asm expression.
2173///
2174bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2175                                      const char *ExtraCode, raw_ostream &O) {
2176  if (ExtraCode && ExtraCode[0]) {
2177    if (ExtraCode[1] != 0)
2178      return true; // Unknown modifier.
2179
2180    switch (ExtraCode[0]) {
2181    default:
2182      // See if this is a generic print operand
2183      return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2184    case 'r':
2185      break;
2186    }
2187  }
2188
2189  printOperand(MI, OpNo, O);
2190
2191  return false;
2192}
2193
2194bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2195                                            unsigned OpNo,
2196                                            const char *ExtraCode,
2197                                            raw_ostream &O) {
2198  if (ExtraCode && ExtraCode[0])
2199    return true; // Unknown modifier
2200
2201  O << '[';
2202  printMemOperand(MI, OpNo, O);
2203  O << ']';
2204
2205  return false;
2206}
2207
2208void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
2209                                   raw_ostream &O) {
2210  const MachineOperand &MO = MI->getOperand(opNum);
2211  switch (MO.getType()) {
2212  case MachineOperand::MO_Register:
2213    if (Register::isPhysicalRegister(MO.getReg())) {
2214      if (MO.getReg() == NVPTX::VRDepot)
2215        O << DEPOTNAME << getFunctionNumber();
2216      else
2217        O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2218    } else {
2219      emitVirtualRegister(MO.getReg(), O);
2220    }
2221    break;
2222
2223  case MachineOperand::MO_Immediate:
2224    O << MO.getImm();
2225    break;
2226
2227  case MachineOperand::MO_FPImmediate:
2228    printFPConstant(MO.getFPImm(), O);
2229    break;
2230
2231  case MachineOperand::MO_GlobalAddress:
2232    PrintSymbolOperand(MO, O);
2233    break;
2234
2235  case MachineOperand::MO_MachineBasicBlock:
2236    MO.getMBB()->getSymbol()->print(O, MAI);
2237    break;
2238
2239  default:
2240    llvm_unreachable("Operand type not supported.");
2241  }
2242}
2243
2244void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
2245                                      raw_ostream &O, const char *Modifier) {
2246  printOperand(MI, opNum, O);
2247
2248  if (Modifier && strcmp(Modifier, "add") == 0) {
2249    O << ", ";
2250    printOperand(MI, opNum + 1, O);
2251  } else {
2252    if (MI->getOperand(opNum + 1).isImm() &&
2253        MI->getOperand(opNum + 1).getImm() == 0)
2254      return; // don't print ',0' or '+0'
2255    O << "+";
2256    printOperand(MI, opNum + 1, O);
2257  }
2258}
2259
2260// Force static initialization.
2261extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2262  RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2263  RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2264}
2265