1//===- Coroutines.cpp -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the common infrastructure for Coroutine Passes.
10//
11//===----------------------------------------------------------------------===//
12
13#include "CoroInstr.h"
14#include "CoroInternal.h"
15#include "llvm/ADT/SmallVector.h"
16#include "llvm/ADT/StringRef.h"
17#include "llvm/Analysis/CallGraph.h"
18#include "llvm/IR/Attributes.h"
19#include "llvm/IR/Constants.h"
20#include "llvm/IR/DerivedTypes.h"
21#include "llvm/IR/Function.h"
22#include "llvm/IR/InstIterator.h"
23#include "llvm/IR/Instructions.h"
24#include "llvm/IR/IntrinsicInst.h"
25#include "llvm/IR/Intrinsics.h"
26#include "llvm/IR/Module.h"
27#include "llvm/IR/Type.h"
28#include "llvm/Support/Casting.h"
29#include "llvm/Support/ErrorHandling.h"
30#include "llvm/Transforms/Utils/Local.h"
31#include <cassert>
32#include <cstddef>
33#include <utility>
34
35using namespace llvm;
36
// Construct the lowerer base class and initialize its members.
//
// Caches values shared by the coroutine lowering code for module M:
//  - TheModule / Context: the module and its LLVMContext;
//  - Int8Ptr: the pointer type in address space 0;
//  - ResumeFnType: the non-variadic function type `void(ptr)` (presumably the
//    signature of resume/destroy functions -- TODO confirm against users);
//  - NullPtr: a null constant of type Int8Ptr.
//
// NOTE: later initializers read earlier members (Context, Int8Ptr), so this
// list relies on those members being declared in this same order in the
// class definition.
coro::LowererBase::LowererBase(Module &M)
    : TheModule(M), Context(M.getContext()),
      Int8Ptr(PointerType::get(Context, 0)),
      ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
                                     /*isVarArg=*/false)),
      NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
44
45// Creates a call to llvm.coro.subfn.addr to obtain a resume function address.
46// It generates the following:
47//
48//    call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index)
49
50Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
51                                        Instruction *InsertPt) {
52  auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
53  auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
54
55  assert(Index >= CoroSubFnInst::IndexFirst &&
56         Index < CoroSubFnInst::IndexLast &&
57         "makeSubFnCall: Index value out of range");
58  return CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt);
59}
60
// Names of every coroutine intrinsic. Scanned by declaresAnyIntrinsic() and,
// in debug builds, searched by isCoroutineIntrinsicName().
//
// NOTE: Must be sorted! The table is handed to
// Intrinsic::lookupLLVMIntrinsicByName, which presumably relies on sorted
// input (e.g. a binary search). Keep entries in lexicographic order when
// adding a new intrinsic.
static const char *const CoroIntrinsics[] = {
    "llvm.coro.align",
    "llvm.coro.alloc",
    "llvm.coro.async.context.alloc",
    "llvm.coro.async.context.dealloc",
    "llvm.coro.async.resume",
    "llvm.coro.async.size.replace",
    "llvm.coro.async.store_resume",
    "llvm.coro.begin",
    "llvm.coro.destroy",
    "llvm.coro.done",
    "llvm.coro.end",
    "llvm.coro.end.async",
    "llvm.coro.frame",
    "llvm.coro.free",
    "llvm.coro.id",
    "llvm.coro.id.async",
    "llvm.coro.id.retcon",
    "llvm.coro.id.retcon.once",
    "llvm.coro.noop",
    "llvm.coro.prepare.async",
    "llvm.coro.prepare.retcon",
    "llvm.coro.promise",
    "llvm.coro.resume",
    "llvm.coro.save",
    "llvm.coro.size",
    "llvm.coro.subfn.addr",
    "llvm.coro.suspend",
    "llvm.coro.suspend.async",
    "llvm.coro.suspend.retcon",
};
93
94#ifndef NDEBUG
95static bool isCoroutineIntrinsicName(StringRef Name) {
96  return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
97}
98#endif
99
100bool coro::declaresAnyIntrinsic(const Module &M) {
101  for (StringRef Name : CoroIntrinsics) {
102    assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
103    if (M.getNamedValue(Name))
104      return true;
105  }
106
107  return false;
108}
109
110// Verifies if a module has named values listed. Also, in debug mode verifies
111// that names are intrinsic names.
112bool coro::declaresIntrinsics(const Module &M,
113                              const std::initializer_list<StringRef> List) {
114  for (StringRef Name : List) {
115    assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
116    if (M.getNamedValue(Name))
117      return true;
118  }
119
120  return false;
121}
122
123// Replace all coro.frees associated with the provided CoroId either with 'null'
124// if Elide is true and with its frame parameter otherwise.
125void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
126  SmallVector<CoroFreeInst *, 4> CoroFrees;
127  for (User *U : CoroId->users())
128    if (auto CF = dyn_cast<CoroFreeInst>(U))
129      CoroFrees.push_back(CF);
130
131  if (CoroFrees.empty())
132    return;
133
134  Value *Replacement =
135      Elide
136          ? ConstantPointerNull::get(PointerType::get(CoroId->getContext(), 0))
137          : CoroFrees.front()->getFrame();
138
139  for (CoroFreeInst *CF : CoroFrees) {
140    CF->replaceAllUsesWith(Replacement);
141    CF->eraseFromParent();
142  }
143}
144
145static void clear(coro::Shape &Shape) {
146  Shape.CoroBegin = nullptr;
147  Shape.CoroEnds.clear();
148  Shape.CoroSizes.clear();
149  Shape.CoroSuspends.clear();
150
151  Shape.FrameTy = nullptr;
152  Shape.FramePtr = nullptr;
153  Shape.AllocaSpillBlock = nullptr;
154}
155
156static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
157                                    CoroSuspendInst *SuspendInst) {
158  Module *M = SuspendInst->getModule();
159  auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
160  auto *SaveInst =
161      cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst));
162  assert(!SuspendInst->getCoroSave());
163  SuspendInst->setArgOperand(0, SaveInst);
164  return SaveInst;
165}
166
167// Collect "interesting" coroutine intrinsics.
168void coro::Shape::buildFrom(Function &F) {
169  bool HasFinalSuspend = false;
170  bool HasUnwindCoroEnd = false;
171  size_t FinalSuspendIndex = 0;
172  clear(*this);
173  SmallVector<CoroFrameInst *, 8> CoroFrames;
174  SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;
175
176  for (Instruction &I : instructions(F)) {
177    if (auto II = dyn_cast<IntrinsicInst>(&I)) {
178      switch (II->getIntrinsicID()) {
179      default:
180        continue;
181      case Intrinsic::coro_size:
182        CoroSizes.push_back(cast<CoroSizeInst>(II));
183        break;
184      case Intrinsic::coro_align:
185        CoroAligns.push_back(cast<CoroAlignInst>(II));
186        break;
187      case Intrinsic::coro_frame:
188        CoroFrames.push_back(cast<CoroFrameInst>(II));
189        break;
190      case Intrinsic::coro_save:
191        // After optimizations, coro_suspends using this coro_save might have
192        // been removed, remember orphaned coro_saves to remove them later.
193        if (II->use_empty())
194          UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
195        break;
196      case Intrinsic::coro_suspend_async: {
197        auto *Suspend = cast<CoroSuspendAsyncInst>(II);
198        Suspend->checkWellFormed();
199        CoroSuspends.push_back(Suspend);
200        break;
201      }
202      case Intrinsic::coro_suspend_retcon: {
203        auto Suspend = cast<CoroSuspendRetconInst>(II);
204        CoroSuspends.push_back(Suspend);
205        break;
206      }
207      case Intrinsic::coro_suspend: {
208        auto Suspend = cast<CoroSuspendInst>(II);
209        CoroSuspends.push_back(Suspend);
210        if (Suspend->isFinal()) {
211          if (HasFinalSuspend)
212            report_fatal_error(
213              "Only one suspend point can be marked as final");
214          HasFinalSuspend = true;
215          FinalSuspendIndex = CoroSuspends.size() - 1;
216        }
217        break;
218      }
219      case Intrinsic::coro_begin: {
220        auto CB = cast<CoroBeginInst>(II);
221
222        // Ignore coro id's that aren't pre-split.
223        auto Id = dyn_cast<CoroIdInst>(CB->getId());
224        if (Id && !Id->getInfo().isPreSplit())
225          break;
226
227        if (CoroBegin)
228          report_fatal_error(
229                "coroutine should have exactly one defining @llvm.coro.begin");
230        CB->addRetAttr(Attribute::NonNull);
231        CB->addRetAttr(Attribute::NoAlias);
232        CB->removeFnAttr(Attribute::NoDuplicate);
233        CoroBegin = CB;
234        break;
235      }
236      case Intrinsic::coro_end_async:
237      case Intrinsic::coro_end:
238        CoroEnds.push_back(cast<AnyCoroEndInst>(II));
239        if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) {
240          AsyncEnd->checkWellFormed();
241        }
242
243        if (CoroEnds.back()->isUnwind())
244          HasUnwindCoroEnd = true;
245
246        if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) {
247          // Make sure that the fallthrough coro.end is the first element in the
248          // CoroEnds vector.
249          // Note: I don't think this is neccessary anymore.
250          if (CoroEnds.size() > 1) {
251            if (CoroEnds.front()->isFallthrough())
252              report_fatal_error(
253                  "Only one coro.end can be marked as fallthrough");
254            std::swap(CoroEnds.front(), CoroEnds.back());
255          }
256        }
257        break;
258      }
259    }
260  }
261
262  // If for some reason, we were not able to find coro.begin, bailout.
263  if (!CoroBegin) {
264    // Replace coro.frame which are supposed to be lowered to the result of
265    // coro.begin with undef.
266    auto *Undef = UndefValue::get(PointerType::get(F.getContext(), 0));
267    for (CoroFrameInst *CF : CoroFrames) {
268      CF->replaceAllUsesWith(Undef);
269      CF->eraseFromParent();
270    }
271
272    // Replace all coro.suspend with undef and remove related coro.saves if
273    // present.
274    for (AnyCoroSuspendInst *CS : CoroSuspends) {
275      CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
276      CS->eraseFromParent();
277      if (auto *CoroSave = CS->getCoroSave())
278        CoroSave->eraseFromParent();
279    }
280
281    // Replace all coro.ends with unreachable instruction.
282    for (AnyCoroEndInst *CE : CoroEnds)
283      changeToUnreachable(CE);
284
285    return;
286  }
287
288  auto Id = CoroBegin->getId();
289  switch (auto IdIntrinsic = Id->getIntrinsicID()) {
290  case Intrinsic::coro_id: {
291    auto SwitchId = cast<CoroIdInst>(Id);
292    this->ABI = coro::ABI::Switch;
293    this->SwitchLowering.HasFinalSuspend = HasFinalSuspend;
294    this->SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;
295    this->SwitchLowering.ResumeSwitch = nullptr;
296    this->SwitchLowering.PromiseAlloca = SwitchId->getPromise();
297    this->SwitchLowering.ResumeEntryBlock = nullptr;
298
299    for (auto *AnySuspend : CoroSuspends) {
300      auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
301      if (!Suspend) {
302#ifndef NDEBUG
303        AnySuspend->dump();
304#endif
305        report_fatal_error("coro.id must be paired with coro.suspend");
306      }
307
308      if (!Suspend->getCoroSave())
309        createCoroSave(CoroBegin, Suspend);
310    }
311    break;
312  }
313  case Intrinsic::coro_id_async: {
314    auto *AsyncId = cast<CoroIdAsyncInst>(Id);
315    AsyncId->checkWellFormed();
316    this->ABI = coro::ABI::Async;
317    this->AsyncLowering.Context = AsyncId->getStorage();
318    this->AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
319    this->AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
320    this->AsyncLowering.ContextAlignment =
321        AsyncId->getStorageAlignment().value();
322    this->AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
323    this->AsyncLowering.AsyncCC = F.getCallingConv();
324    break;
325  };
326  case Intrinsic::coro_id_retcon:
327  case Intrinsic::coro_id_retcon_once: {
328    auto ContinuationId = cast<AnyCoroIdRetconInst>(Id);
329    ContinuationId->checkWellFormed();
330    this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon
331                  ? coro::ABI::Retcon
332                  : coro::ABI::RetconOnce);
333    auto Prototype = ContinuationId->getPrototype();
334    this->RetconLowering.ResumePrototype = Prototype;
335    this->RetconLowering.Alloc = ContinuationId->getAllocFunction();
336    this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
337    this->RetconLowering.ReturnBlock = nullptr;
338    this->RetconLowering.IsFrameInlineInStorage = false;
339
340    // Determine the result value types, and make sure they match up with
341    // the values passed to the suspends.
342    auto ResultTys = getRetconResultTypes();
343    auto ResumeTys = getRetconResumeTypes();
344
345    for (auto *AnySuspend : CoroSuspends) {
346      auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
347      if (!Suspend) {
348#ifndef NDEBUG
349        AnySuspend->dump();
350#endif
351        report_fatal_error("coro.id.retcon.* must be paired with "
352                           "coro.suspend.retcon");
353      }
354
355      // Check that the argument types of the suspend match the results.
356      auto SI = Suspend->value_begin(), SE = Suspend->value_end();
357      auto RI = ResultTys.begin(), RE = ResultTys.end();
358      for (; SI != SE && RI != RE; ++SI, ++RI) {
359        auto SrcTy = (*SI)->getType();
360        if (SrcTy != *RI) {
361          // The optimizer likes to eliminate bitcasts leading into variadic
362          // calls, but that messes with our invariants.  Re-insert the
363          // bitcast and ignore this type mismatch.
364          if (CastInst::isBitCastable(SrcTy, *RI)) {
365            auto BCI = new BitCastInst(*SI, *RI, "", Suspend);
366            SI->set(BCI);
367            continue;
368          }
369
370#ifndef NDEBUG
371          Suspend->dump();
372          Prototype->getFunctionType()->dump();
373#endif
374          report_fatal_error("argument to coro.suspend.retcon does not "
375                             "match corresponding prototype function result");
376        }
377      }
378      if (SI != SE || RI != RE) {
379#ifndef NDEBUG
380        Suspend->dump();
381        Prototype->getFunctionType()->dump();
382#endif
383        report_fatal_error("wrong number of arguments to coro.suspend.retcon");
384      }
385
386      // Check that the result type of the suspend matches the resume types.
387      Type *SResultTy = Suspend->getType();
388      ArrayRef<Type*> SuspendResultTys;
389      if (SResultTy->isVoidTy()) {
390        // leave as empty array
391      } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
392        SuspendResultTys = SResultStructTy->elements();
393      } else {
394        // forms an ArrayRef using SResultTy, be careful
395        SuspendResultTys = SResultTy;
396      }
397      if (SuspendResultTys.size() != ResumeTys.size()) {
398#ifndef NDEBUG
399        Suspend->dump();
400        Prototype->getFunctionType()->dump();
401#endif
402        report_fatal_error("wrong number of results from coro.suspend.retcon");
403      }
404      for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
405        if (SuspendResultTys[I] != ResumeTys[I]) {
406#ifndef NDEBUG
407          Suspend->dump();
408          Prototype->getFunctionType()->dump();
409#endif
410          report_fatal_error("result from coro.suspend.retcon does not "
411                             "match corresponding prototype function param");
412        }
413      }
414    }
415    break;
416  }
417
418  default:
419    llvm_unreachable("coro.begin is not dependent on a coro.id call");
420  }
421
422  // The coro.free intrinsic is always lowered to the result of coro.begin.
423  for (CoroFrameInst *CF : CoroFrames) {
424    CF->replaceAllUsesWith(CoroBegin);
425    CF->eraseFromParent();
426  }
427
428  // Move final suspend to be the last element in the CoroSuspends vector.
429  if (ABI == coro::ABI::Switch &&
430      SwitchLowering.HasFinalSuspend &&
431      FinalSuspendIndex != CoroSuspends.size() - 1)
432    std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
433
434  // Remove orphaned coro.saves.
435  for (CoroSaveInst *CoroSave : UnusedCoroSaves)
436    CoroSave->eraseFromParent();
437}
438
439static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
440  Call->setCallingConv(Callee->getCallingConv());
441  // TODO: attributes?
442}
443
444static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
445  if (CG)
446    (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
447}
448
449Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
450                              CallGraph *CG) const {
451  switch (ABI) {
452  case coro::ABI::Switch:
453    llvm_unreachable("can't allocate memory in coro switch-lowering");
454
455  case coro::ABI::Retcon:
456  case coro::ABI::RetconOnce: {
457    auto Alloc = RetconLowering.Alloc;
458    Size = Builder.CreateIntCast(Size,
459                                 Alloc->getFunctionType()->getParamType(0),
460                                 /*is signed*/ false);
461    auto *Call = Builder.CreateCall(Alloc, Size);
462    propagateCallAttrsFromCallee(Call, Alloc);
463    addCallToCallGraph(CG, Call, Alloc);
464    return Call;
465  }
466  case coro::ABI::Async:
467    llvm_unreachable("can't allocate memory in coro async-lowering");
468  }
469  llvm_unreachable("Unknown coro::ABI enum");
470}
471
472void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
473                              CallGraph *CG) const {
474  switch (ABI) {
475  case coro::ABI::Switch:
476    llvm_unreachable("can't allocate memory in coro switch-lowering");
477
478  case coro::ABI::Retcon:
479  case coro::ABI::RetconOnce: {
480    auto Dealloc = RetconLowering.Dealloc;
481    Ptr = Builder.CreateBitCast(Ptr,
482                                Dealloc->getFunctionType()->getParamType(0));
483    auto *Call = Builder.CreateCall(Dealloc, Ptr);
484    propagateCallAttrsFromCallee(Call, Dealloc);
485    addCallToCallGraph(CG, Call, Dealloc);
486    return;
487  }
488  case coro::ABI::Async:
489    llvm_unreachable("can't allocate memory in coro async-lowering");
490  }
491  llvm_unreachable("Unknown coro::ABI enum");
492}
493
494[[noreturn]] static void fail(const Instruction *I, const char *Reason,
495                              Value *V) {
496#ifndef NDEBUG
497  I->dump();
498  if (V) {
499    errs() << "  Value: ";
500    V->printAsOperand(llvm::errs());
501    errs() << '\n';
502  }
503#endif
504  report_fatal_error(Reason);
505}
506
507/// Check that the given value is a well-formed prototype for the
508/// llvm.coro.id.retcon.* intrinsics.
509static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
510  auto F = dyn_cast<Function>(V->stripPointerCasts());
511  if (!F)
512    fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
513
514  auto FT = F->getFunctionType();
515
516  if (isa<CoroIdRetconInst>(I)) {
517    bool ResultOkay;
518    if (FT->getReturnType()->isPointerTy()) {
519      ResultOkay = true;
520    } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
521      ResultOkay = (!SRetTy->isOpaque() &&
522                    SRetTy->getNumElements() > 0 &&
523                    SRetTy->getElementType(0)->isPointerTy());
524    } else {
525      ResultOkay = false;
526    }
527    if (!ResultOkay)
528      fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
529              "result", F);
530
531    if (FT->getReturnType() !=
532          I->getFunction()->getFunctionType()->getReturnType())
533      fail(I, "llvm.coro.id.retcon prototype return type must be same as"
534              "current function return type", F);
535  } else {
536    // No meaningful validation to do here for llvm.coro.id.unique.once.
537  }
538
539  if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
540    fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
541            "its first parameter", F);
542}
543
544/// Check that the given value is a well-formed allocator.
545static void checkWFAlloc(const Instruction *I, Value *V) {
546  auto F = dyn_cast<Function>(V->stripPointerCasts());
547  if (!F)
548    fail(I, "llvm.coro.* allocator not a Function", V);
549
550  auto FT = F->getFunctionType();
551  if (!FT->getReturnType()->isPointerTy())
552    fail(I, "llvm.coro.* allocator must return a pointer", F);
553
554  if (FT->getNumParams() != 1 ||
555      !FT->getParamType(0)->isIntegerTy())
556    fail(I, "llvm.coro.* allocator must take integer as only param", F);
557}
558
559/// Check that the given value is a well-formed deallocator.
560static void checkWFDealloc(const Instruction *I, Value *V) {
561  auto F = dyn_cast<Function>(V->stripPointerCasts());
562  if (!F)
563    fail(I, "llvm.coro.* deallocator not a Function", V);
564
565  auto FT = F->getFunctionType();
566  if (!FT->getReturnType()->isVoidTy())
567    fail(I, "llvm.coro.* deallocator must return void", F);
568
569  if (FT->getNumParams() != 1 ||
570      !FT->getParamType(0)->isPointerTy())
571    fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
572}
573
574static void checkConstantInt(const Instruction *I, Value *V,
575                             const char *Reason) {
576  if (!isa<ConstantInt>(V)) {
577    fail(I, Reason, V);
578  }
579}
580
581void AnyCoroIdRetconInst::checkWellFormed() const {
582  checkConstantInt(this, getArgOperand(SizeArg),
583                   "size argument to coro.id.retcon.* must be constant");
584  checkConstantInt(this, getArgOperand(AlignArg),
585                   "alignment argument to coro.id.retcon.* must be constant");
586  checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
587  checkWFAlloc(this, getArgOperand(AllocArg));
588  checkWFDealloc(this, getArgOperand(DeallocArg));
589}
590
591static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
592  auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
593  if (!AsyncFuncPtrAddr)
594    fail(I, "llvm.coro.id.async async function pointer not a global", V);
595}
596
597void CoroIdAsyncInst::checkWellFormed() const {
598  checkConstantInt(this, getArgOperand(SizeArg),
599                   "size argument to coro.id.async must be constant");
600  checkConstantInt(this, getArgOperand(AlignArg),
601                   "alignment argument to coro.id.async must be constant");
602  checkConstantInt(this, getArgOperand(StorageArg),
603                   "storage argument offset to coro.id.async must be constant");
604  checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
605}
606
607static void checkAsyncContextProjectFunction(const Instruction *I,
608                                             Function *F) {
609  auto *FunTy = cast<FunctionType>(F->getValueType());
610  if (!FunTy->getReturnType()->isPointerTy())
611    fail(I,
612         "llvm.coro.suspend.async resume function projection function must "
613         "return a ptr type",
614         F);
615  if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy())
616    fail(I,
617         "llvm.coro.suspend.async resume function projection function must "
618         "take one ptr type as parameter",
619         F);
620}
621
622void CoroSuspendAsyncInst::checkWellFormed() const {
623  checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
624}
625
626void CoroAsyncEndInst::checkWellFormed() const {
627  auto *MustTailCallFunc = getMustTailCallFunction();
628  if (!MustTailCallFunc)
629    return;
630  auto *FnTy = MustTailCallFunc->getFunctionType();
631  if (FnTy->getNumParams() != (arg_size() - 3))
632    fail(this,
633         "llvm.coro.end.async must tail call function argument type must "
634         "match the tail arguments",
635         MustTailCallFunc);
636}
637