MVETailPredication.cpp revision 360784
1//===- MVETailPredication.cpp - MVE Tail Predication ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file
10/// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
11/// branches to help accelerate DSP applications. These two extensions can be
12/// combined to provide implicit vector predication within a low-overhead loop.
13/// The HardwareLoops pass inserts intrinsics identifying loops that the
14/// backend will attempt to convert into a low-overhead loop. The vectorizer is
15/// responsible for generating a vectorized loop in which the lanes are
16/// predicated upon the iteration counter. This pass looks at these predicated
17/// vector loops, that are targets for low-overhead loops, and prepares it for
18/// code generation. Once the vectorizer has produced a masked loop, there's a
19/// couple of final forms:
20/// - A tail-predicated loop, with implicit predication.
/// - A loop containing multiple VCTP instructions, predicating multiple VPT
///   blocks of instructions operating on different vector types.
23///
/// This pass inserts the VCTP intrinsic to represent the effect of
/// tail predication. This will be picked up by the ARM Low-overhead loop pass,
/// which performs the final transformation to a DLSTP or WLSTP tail-predicated
/// loop.
28
29#include "ARM.h"
30#include "ARMSubtarget.h"
31#include "llvm/Analysis/LoopInfo.h"
32#include "llvm/Analysis/LoopPass.h"
33#include "llvm/Analysis/ScalarEvolution.h"
34#include "llvm/Analysis/ScalarEvolutionExpander.h"
35#include "llvm/Analysis/ScalarEvolutionExpressions.h"
36#include "llvm/Analysis/TargetTransformInfo.h"
37#include "llvm/CodeGen/TargetPassConfig.h"
38#include "llvm/IR/IRBuilder.h"
39#include "llvm/IR/Instructions.h"
40#include "llvm/IR/IntrinsicsARM.h"
41#include "llvm/IR/PatternMatch.h"
42#include "llvm/Support/Debug.h"
43#include "llvm/Transforms/Utils/BasicBlockUtils.h"
44
45using namespace llvm;
46
47#define DEBUG_TYPE "mve-tail-predication"
48#define DESC "Transform predicated vector loops to use MVE tail predication"
49
50cl::opt<bool>
51DisableTailPredication("disable-mve-tail-predication", cl::Hidden,
52                       cl::init(true),
53                       cl::desc("Disable MVE Tail Predication"));
54namespace {
55
56class MVETailPredication : public LoopPass {
57  SmallVector<IntrinsicInst*, 4> MaskedInsts;
58  Loop *L = nullptr;
59  ScalarEvolution *SE = nullptr;
60  TargetTransformInfo *TTI = nullptr;
61
62public:
63  static char ID;
64
65  MVETailPredication() : LoopPass(ID) { }
66
67  void getAnalysisUsage(AnalysisUsage &AU) const override {
68    AU.addRequired<ScalarEvolutionWrapperPass>();
69    AU.addRequired<LoopInfoWrapperPass>();
70    AU.addRequired<TargetPassConfig>();
71    AU.addRequired<TargetTransformInfoWrapperPass>();
72    AU.addPreserved<LoopInfoWrapperPass>();
73    AU.setPreservesCFG();
74  }
75
76  bool runOnLoop(Loop *L, LPPassManager&) override;
77
78private:
79
80  /// Perform the relevant checks on the loop and convert if possible.
81  bool TryConvert(Value *TripCount);
82
83  /// Return whether this is a vectorized loop, that contains masked
84  /// load/stores.
85  bool IsPredicatedVectorLoop();
86
87  /// Compute a value for the total number of elements that the predicated
88  /// loop will process.
89  Value *ComputeElements(Value *TripCount, VectorType *VecTy);
90
91  /// Is the icmp that generates an i1 vector, based upon a loop counter
92  /// and a limit that is defined outside the loop.
93  bool isTailPredicate(Instruction *Predicate, Value *NumElements);
94
95  /// Insert the intrinsic to represent the effect of tail predication.
96  void InsertVCTPIntrinsic(Instruction *Predicate,
97                           DenseMap<Instruction*, Instruction*> &NewPredicates,
98                           VectorType *VecTy,
99                           Value *NumElements);
100};
101
102} // end namespace
103
104static bool IsDecrement(Instruction &I) {
105  auto *Call = dyn_cast<IntrinsicInst>(&I);
106  if (!Call)
107    return false;
108
109  Intrinsic::ID ID = Call->getIntrinsicID();
110  return ID == Intrinsic::loop_decrement_reg;
111}
112
113static bool IsMasked(Instruction *I) {
114  auto *Call = dyn_cast<IntrinsicInst>(I);
115  if (!Call)
116    return false;
117
118  Intrinsic::ID ID = Call->getIntrinsicID();
119  // TODO: Support gather/scatter expand/compress operations.
120  return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load;
121}
122
123bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
124  if (skipLoop(L) || DisableTailPredication)
125    return false;
126
127  Function &F = *L->getHeader()->getParent();
128  auto &TPC = getAnalysis<TargetPassConfig>();
129  auto &TM = TPC.getTM<TargetMachine>();
130  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
131  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
132  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
133  this->L = L;
134
135  // The MVE and LOB extensions are combined to enable tail-predication, but
136  // there's nothing preventing us from generating VCTP instructions for v8.1m.
137  if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
138    LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
139    return false;
140  }
141
142  BasicBlock *Preheader = L->getLoopPreheader();
143  if (!Preheader)
144    return false;
145
146  auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
147    for (auto &I : *BB) {
148      auto *Call = dyn_cast<IntrinsicInst>(&I);
149      if (!Call)
150        continue;
151
152      Intrinsic::ID ID = Call->getIntrinsicID();
153      if (ID == Intrinsic::set_loop_iterations ||
154          ID == Intrinsic::test_set_loop_iterations)
155        return cast<IntrinsicInst>(&I);
156    }
157    return nullptr;
158  };
159
160  // Look for the hardware loop intrinsic that sets the iteration count.
161  IntrinsicInst *Setup = FindLoopIterations(Preheader);
162
163  // The test.set iteration could live in the pre-preheader.
164  if (!Setup) {
165    if (!Preheader->getSinglePredecessor())
166      return false;
167    Setup = FindLoopIterations(Preheader->getSinglePredecessor());
168    if (!Setup)
169      return false;
170  }
171
172  // Search for the hardware loop intrinic that decrements the loop counter.
173  IntrinsicInst *Decrement = nullptr;
174  for (auto *BB : L->getBlocks()) {
175    for (auto &I : *BB) {
176      if (IsDecrement(I)) {
177        Decrement = cast<IntrinsicInst>(&I);
178        break;
179      }
180    }
181  }
182
183  if (!Decrement)
184    return false;
185
186  LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n"
187             << *Decrement << "\n");
188  return TryConvert(Setup->getArgOperand(0));
189}
190
191bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) {
192  // Look for the following:
193
194  // %trip.count.minus.1 = add i32 %N, -1
195  // %broadcast.splatinsert10 = insertelement <4 x i32> undef,
196  //                                          i32 %trip.count.minus.1, i32 0
197  // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10,
198  //                                    <4 x i32> undef,
199  //                                    <4 x i32> zeroinitializer
200  // ...
201  // ...
202  // %index = phi i32
203  // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0
204  // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert,
205  //                                  <4 x i32> undef,
206  //                                  <4 x i32> zeroinitializer
207  // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3>
208  // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11
209
210  // And return whether V == %pred.
211
212  using namespace PatternMatch;
213
214  CmpInst::Predicate Pred;
215  Instruction *Shuffle = nullptr;
216  Instruction *Induction = nullptr;
217
218  // The vector icmp
219  if (!match(I, m_ICmp(Pred, m_Instruction(Induction),
220                       m_Instruction(Shuffle))) ||
221      Pred != ICmpInst::ICMP_ULE)
222    return false;
223
224  // First find the stuff outside the loop which is setting up the limit
225  // vector....
226  // The invariant shuffle that broadcast the limit into a vector.
227  Instruction *Insert = nullptr;
228  if (!match(Shuffle, m_ShuffleVector(m_Instruction(Insert), m_Undef(),
229                                      m_Zero())))
230    return false;
231
232  // Insert the limit into a vector.
233  Instruction *BECount = nullptr;
234  if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(BECount),
235                                     m_Zero())))
236    return false;
237
238  // The limit calculation, backedge count.
239  Value *TripCount = nullptr;
240  if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes())))
241    return false;
242
243  if (TripCount != NumElements || !L->isLoopInvariant(BECount))
244    return false;
245
246  // Now back to searching inside the loop body...
247  // Find the add with takes the index iv and adds a constant vector to it.
248  Instruction *BroadcastSplat = nullptr;
249  Constant *Const = nullptr;
250  if (!match(Induction, m_Add(m_Instruction(BroadcastSplat),
251                              m_Constant(Const))))
252   return false;
253
254  // Check that we're adding <0, 1, 2, 3...
255  if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) {
256    for (unsigned i = 0; i < CDS->getNumElements(); ++i) {
257      if (CDS->getElementAsInteger(i) != i)
258        return false;
259    }
260  } else
261    return false;
262
263  // The shuffle which broadcasts the index iv into a vector.
264  if (!match(BroadcastSplat, m_ShuffleVector(m_Instruction(Insert), m_Undef(),
265                                             m_Zero())))
266    return false;
267
268  // The insert element which initialises a vector with the index iv.
269  Instruction *IV = nullptr;
270  if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero())))
271    return false;
272
273  // The index iv.
274  auto *Phi = dyn_cast<PHINode>(IV);
275  if (!Phi)
276    return false;
277
278  // TODO: Don't think we need to check the entry value.
279  Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader());
280  if (!match(OnEntry, m_Zero()))
281    return false;
282
283  Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch());
284  unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements();
285
286  Instruction *LHS = nullptr;
287  if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes))))
288    return false;
289
290  return LHS == Phi;
291}
292
293static VectorType* getVectorType(IntrinsicInst *I) {
294  unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1;
295  auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType());
296  return cast<VectorType>(PtrTy->getElementType());
297}
298
299bool MVETailPredication::IsPredicatedVectorLoop() {
300  // Check that the loop contains at least one masked load/store intrinsic.
301  // We only support 'normal' vector instructions - other than masked
302  // load/stores.
303  for (auto *BB : L->getBlocks()) {
304    for (auto &I : *BB) {
305      if (IsMasked(&I)) {
306        VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I));
307        unsigned Lanes = VecTy->getNumElements();
308        unsigned ElementWidth = VecTy->getScalarSizeInBits();
309        // MVE vectors are 128-bit, but don't support 128 x i1.
310        // TODO: Can we support vectors larger than 128-bits?
311        unsigned MaxWidth = TTI->getRegisterBitWidth(true);
312        if (Lanes * ElementWidth > MaxWidth || Lanes == MaxWidth)
313          return false;
314        MaskedInsts.push_back(cast<IntrinsicInst>(&I));
315      } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) {
316        for (auto &U : Int->args()) {
317          if (isa<VectorType>(U->getType()))
318            return false;
319        }
320      }
321    }
322  }
323
324  return !MaskedInsts.empty();
325}
326
327Value* MVETailPredication::ComputeElements(Value *TripCount,
328                                           VectorType *VecTy) {
329  const SCEV *TripCountSE = SE->getSCEV(TripCount);
330  ConstantInt *VF = ConstantInt::get(cast<IntegerType>(TripCount->getType()),
331                                     VecTy->getNumElements());
332
333  if (VF->equalsInt(1))
334    return nullptr;
335
336  // TODO: Support constant trip counts.
337  auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr* {
338    if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
339      if (Const->getAPInt() != -VF->getValue())
340        return nullptr;
341    } else
342      return nullptr;
343    return dyn_cast<SCEVMulExpr>(S->getOperand(1));
344  };
345
346  auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr* {
347    if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) {
348      if (Const->getValue() != VF)
349        return nullptr;
350    } else
351      return nullptr;
352    return dyn_cast<SCEVUDivExpr>(S->getOperand(1));
353  };
354
355  auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV* {
356    if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) {
357      if (Const->getValue() != VF)
358        return nullptr;
359    } else
360      return nullptr;
361
362    if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) {
363      if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) {
364        if (Const->getAPInt() != (VF->getValue() - 1))
365          return nullptr;
366      } else
367        return nullptr;
368
369      return RoundUp->getOperand(1);
370    }
371    return nullptr;
372  };
373
374  // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to
375  // determine the numbers of elements instead? Looks like this is what is used
376  // for delinearization, but I'm not sure if it can be applied to the
377  // vectorized form - at least not without a bit more work than I feel
378  // comfortable with.
379
380  // Search for Elems in the following SCEV:
381  // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw>
382  const SCEV *Elems = nullptr;
383  if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE))
384    if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1)))
385      if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS()))
386        if (auto *Mul = VisitAdd(Add))
387          if (auto *Div = VisitMul(Mul))
388            if (auto *Res = VisitDiv(Div))
389              Elems = Res;
390
391  if (!Elems)
392    return nullptr;
393
394  Instruction *InsertPt = L->getLoopPreheader()->getTerminator();
395  if (!isSafeToExpandAt(Elems, InsertPt, *SE))
396    return nullptr;
397
398  auto DL = L->getHeader()->getModule()->getDataLayout();
399  SCEVExpander Expander(*SE, DL, "elements");
400  return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt);
401}
402
403// Look through the exit block to see whether there's a duplicate predicate
404// instruction. This can happen when we need to perform a select on values
405// from the last and previous iteration. Instead of doing a straight
406// replacement of that predicate with the vctp, clone the vctp and place it
407// in the block. This means that the VPR doesn't have to be live into the
408// exit block which should make it easier to convert this loop into a proper
409// tail predicated loop.
410static void Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates,
411                    SetVector<Instruction*> &MaybeDead, Loop *L) {
412  BasicBlock *Exit = L->getUniqueExitBlock();
413  if (!Exit) {
414    LLVM_DEBUG(dbgs() << "ARM TP: can't find loop exit block\n");
415    return;
416  }
417
418  for (auto &Pair : NewPredicates) {
419    Instruction *OldPred = Pair.first;
420    Instruction *NewPred = Pair.second;
421
422    for (auto &I : *Exit) {
423      if (I.isSameOperationAs(OldPred)) {
424        Instruction *PredClone = NewPred->clone();
425        PredClone->insertBefore(&I);
426        I.replaceAllUsesWith(PredClone);
427        MaybeDead.insert(&I);
428        LLVM_DEBUG(dbgs() << "ARM TP: replacing: "; I.dump();
429                   dbgs() << "ARM TP: with:      "; PredClone->dump());
430        break;
431      }
432    }
433  }
434
435  // Drop references and add operands to check for dead.
436  SmallPtrSet<Instruction*, 4> Dead;
437  while (!MaybeDead.empty()) {
438    auto *I = MaybeDead.front();
439    MaybeDead.remove(I);
440    if (I->hasNUsesOrMore(1))
441      continue;
442
443    for (auto &U : I->operands()) {
444      if (auto *OpI = dyn_cast<Instruction>(U))
445        MaybeDead.insert(OpI);
446    }
447    I->dropAllReferences();
448    Dead.insert(I);
449  }
450
451  for (auto *I : Dead) {
452    LLVM_DEBUG(dbgs() << "ARM TP: removing dead insn: "; I->dump());
453    I->eraseFromParent();
454  }
455
456  for (auto I : L->blocks())
457    DeleteDeadPHIs(I);
458}
459
460void MVETailPredication::InsertVCTPIntrinsic(Instruction *Predicate,
461    DenseMap<Instruction*, Instruction*> &NewPredicates,
462    VectorType *VecTy, Value *NumElements) {
463  IRBuilder<> Builder(L->getHeader()->getFirstNonPHI());
464  Module *M = L->getHeader()->getModule();
465  Type *Ty = IntegerType::get(M->getContext(), 32);
466
467  // Insert a phi to count the number of elements processed by the loop.
468  PHINode *Processed = Builder.CreatePHI(Ty, 2);
469  Processed->addIncoming(NumElements, L->getLoopPreheader());
470
471  // Insert the intrinsic to represent the effect of tail predication.
472  Builder.SetInsertPoint(cast<Instruction>(Predicate));
473  ConstantInt *Factor =
474    ConstantInt::get(cast<IntegerType>(Ty), VecTy->getNumElements());
475
476  Intrinsic::ID VCTPID;
477  switch (VecTy->getNumElements()) {
478  default:
479    llvm_unreachable("unexpected number of lanes");
480  case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
481  case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break;
482  case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;
483
484    // FIXME: vctp64 currently not supported because the predicate
485    // vector wants to be <2 x i1>, but v2i1 is not a legal MVE
486    // type, so problems happen at isel time.
487    // Intrinsic::arm_mve_vctp64 exists for ACLE intrinsics
488    // purposes, but takes a v4i1 instead of a v2i1.
489  }
490  Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
491  Value *TailPredicate = Builder.CreateCall(VCTP, Processed);
492  Predicate->replaceAllUsesWith(TailPredicate);
493  NewPredicates[Predicate] = cast<Instruction>(TailPredicate);
494
495  // Add the incoming value to the new phi.
496  // TODO: This add likely already exists in the loop.
497  Value *Remaining = Builder.CreateSub(Processed, Factor);
498  Processed->addIncoming(Remaining, L->getLoopLatch());
499  LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
500             << *Processed << "\n"
501             << "ARM TP: Inserted VCTP: " << *TailPredicate << "\n");
502}
503
504bool MVETailPredication::TryConvert(Value *TripCount) {
505  if (!IsPredicatedVectorLoop()) {
506    LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop");
507    return false;
508  }
509
510  LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");
511
512  // Walk through the masked intrinsics and try to find whether the predicate
513  // operand is generated from an induction variable.
514  SetVector<Instruction*> Predicates;
515  DenseMap<Instruction*, Instruction*> NewPredicates;
516
517  for (auto *I : MaskedInsts) {
518    Intrinsic::ID ID = I->getIntrinsicID();
519    unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3;
520    auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp));
521    if (!Predicate || Predicates.count(Predicate))
522      continue;
523
524    VectorType *VecTy = getVectorType(I);
525    Value *NumElements = ComputeElements(TripCount, VecTy);
526    if (!NumElements)
527      continue;
528
529    if (!isTailPredicate(Predicate, NumElements)) {
530      LLVM_DEBUG(dbgs() << "ARM TP: Not tail predicate: " << *Predicate << "\n");
531      continue;
532    }
533
534    LLVM_DEBUG(dbgs() << "ARM TP: Found tail predicate: " << *Predicate << "\n");
535    Predicates.insert(Predicate);
536
537    InsertVCTPIntrinsic(Predicate, NewPredicates, VecTy, NumElements);
538  }
539
540  // Now clean up.
541  Cleanup(NewPredicates, Predicates, L);
542  return true;
543}
544
545Pass *llvm::createMVETailPredicationPass() {
546  return new MVETailPredication();
547}
548
549char MVETailPredication::ID = 0;
550
551INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
552INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)
553