1//===- LoopIdiomRecognize.cpp - Loop idiom recognition --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements an idiom recognizer that transforms simple loops into a
10// non-loop form.  In cases that this kicks in, it can be a significant
11// performance win.
12//
13// If compiling for code size we avoid idiom recognition if the resulting
14// code could be larger than the code for the original loop. One way this could
15// happen is if the loop is not removable after idiom recognition due to the
16// presence of non-idiom instructions. The initial implementation of the
17// heuristics applies to idioms in multi-block loops.
18//
19//===----------------------------------------------------------------------===//
20//
21// TODO List:
22//
23// Future loop memory idioms to recognize:
24//   memcmp, strlen, etc.
25// Future floating point idioms to recognize in -ffast-math mode:
26//   fpowi
27//
28// This could recognize common matrix multiplies and dot product idioms and
29// replace them with calls to BLAS (if linked in??).
30//
31//===----------------------------------------------------------------------===//
32
33#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h"
34#include "llvm/ADT/APInt.h"
35#include "llvm/ADT/ArrayRef.h"
36#include "llvm/ADT/DenseMap.h"
37#include "llvm/ADT/MapVector.h"
38#include "llvm/ADT/SetVector.h"
39#include "llvm/ADT/SmallPtrSet.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/Statistic.h"
42#include "llvm/ADT/StringRef.h"
43#include "llvm/Analysis/AliasAnalysis.h"
44#include "llvm/Analysis/CmpInstAnalysis.h"
45#include "llvm/Analysis/LoopAccessAnalysis.h"
46#include "llvm/Analysis/LoopInfo.h"
47#include "llvm/Analysis/LoopPass.h"
48#include "llvm/Analysis/MemoryLocation.h"
49#include "llvm/Analysis/MemorySSA.h"
50#include "llvm/Analysis/MemorySSAUpdater.h"
51#include "llvm/Analysis/MustExecute.h"
52#include "llvm/Analysis/OptimizationRemarkEmitter.h"
53#include "llvm/Analysis/ScalarEvolution.h"
54#include "llvm/Analysis/ScalarEvolutionExpressions.h"
55#include "llvm/Analysis/TargetLibraryInfo.h"
56#include "llvm/Analysis/TargetTransformInfo.h"
57#include "llvm/Analysis/ValueTracking.h"
58#include "llvm/IR/BasicBlock.h"
59#include "llvm/IR/Constant.h"
60#include "llvm/IR/Constants.h"
61#include "llvm/IR/DataLayout.h"
62#include "llvm/IR/DebugLoc.h"
63#include "llvm/IR/DerivedTypes.h"
64#include "llvm/IR/Dominators.h"
65#include "llvm/IR/GlobalValue.h"
66#include "llvm/IR/GlobalVariable.h"
67#include "llvm/IR/IRBuilder.h"
68#include "llvm/IR/InstrTypes.h"
69#include "llvm/IR/Instruction.h"
70#include "llvm/IR/Instructions.h"
71#include "llvm/IR/IntrinsicInst.h"
72#include "llvm/IR/Intrinsics.h"
73#include "llvm/IR/LLVMContext.h"
74#include "llvm/IR/Module.h"
75#include "llvm/IR/PassManager.h"
76#include "llvm/IR/PatternMatch.h"
77#include "llvm/IR/Type.h"
78#include "llvm/IR/User.h"
79#include "llvm/IR/Value.h"
80#include "llvm/IR/ValueHandle.h"
81#include "llvm/Support/Casting.h"
82#include "llvm/Support/CommandLine.h"
83#include "llvm/Support/Debug.h"
84#include "llvm/Support/InstructionCost.h"
85#include "llvm/Support/raw_ostream.h"
86#include "llvm/Transforms/Utils/BuildLibCalls.h"
87#include "llvm/Transforms/Utils/Local.h"
88#include "llvm/Transforms/Utils/LoopUtils.h"
89#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
90#include <algorithm>
91#include <cassert>
92#include <cstdint>
93#include <utility>
94#include <vector>
95
96using namespace llvm;
97
98#define DEBUG_TYPE "loop-idiom"
99
100STATISTIC(NumMemSet, "Number of memset's formed from loop stores");
101STATISTIC(NumMemCpy, "Number of memcpy's formed from loop load+stores");
102STATISTIC(NumMemMove, "Number of memmove's formed from loop load+stores");
103STATISTIC(
104    NumShiftUntilBitTest,
105    "Number of uncountable loops recognized as 'shift until bitttest' idiom");
106STATISTIC(NumShiftUntilZero,
107          "Number of uncountable loops recognized as 'shift until zero' idiom");
108
109bool DisableLIRP::All;
110static cl::opt<bool, true>
111    DisableLIRPAll("disable-" DEBUG_TYPE "-all",
112                   cl::desc("Options to disable Loop Idiom Recognize Pass."),
113                   cl::location(DisableLIRP::All), cl::init(false),
114                   cl::ReallyHidden);
115
116bool DisableLIRP::Memset;
117static cl::opt<bool, true>
118    DisableLIRPMemset("disable-" DEBUG_TYPE "-memset",
119                      cl::desc("Proceed with loop idiom recognize pass, but do "
120                               "not convert loop(s) to memset."),
121                      cl::location(DisableLIRP::Memset), cl::init(false),
122                      cl::ReallyHidden);
123
124bool DisableLIRP::Memcpy;
125static cl::opt<bool, true>
126    DisableLIRPMemcpy("disable-" DEBUG_TYPE "-memcpy",
127                      cl::desc("Proceed with loop idiom recognize pass, but do "
128                               "not convert loop(s) to memcpy."),
129                      cl::location(DisableLIRP::Memcpy), cl::init(false),
130                      cl::ReallyHidden);
131
132static cl::opt<bool> UseLIRCodeSizeHeurs(
133    "use-lir-code-size-heurs",
134    cl::desc("Use loop idiom recognition code size heuristics when compiling"
135             "with -Os/-Oz"),
136    cl::init(true), cl::Hidden);
137
138namespace {
139
140class LoopIdiomRecognize {
141  Loop *CurLoop = nullptr;
142  AliasAnalysis *AA;
143  DominatorTree *DT;
144  LoopInfo *LI;
145  ScalarEvolution *SE;
146  TargetLibraryInfo *TLI;
147  const TargetTransformInfo *TTI;
148  const DataLayout *DL;
149  OptimizationRemarkEmitter &ORE;
150  bool ApplyCodeSizeHeuristics;
151  std::unique_ptr<MemorySSAUpdater> MSSAU;
152
153public:
154  explicit LoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT,
155                              LoopInfo *LI, ScalarEvolution *SE,
156                              TargetLibraryInfo *TLI,
157                              const TargetTransformInfo *TTI, MemorySSA *MSSA,
158                              const DataLayout *DL,
159                              OptimizationRemarkEmitter &ORE)
160      : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) {
161    if (MSSA)
162      MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
163  }
164
165  bool runOnLoop(Loop *L);
166
167private:
168  using StoreList = SmallVector<StoreInst *, 8>;
169  using StoreListMap = MapVector<Value *, StoreList>;
170
171  StoreListMap StoreRefsForMemset;
172  StoreListMap StoreRefsForMemsetPattern;
173  StoreList StoreRefsForMemcpy;
174  bool HasMemset;
175  bool HasMemsetPattern;
176  bool HasMemcpy;
177
178  /// Return code for isLegalStore()
179  enum LegalStoreKind {
180    None = 0,
181    Memset,
182    MemsetPattern,
183    Memcpy,
184    UnorderedAtomicMemcpy,
185    DontUse // Dummy retval never to be used. Allows catching errors in retval
186            // handling.
187  };
188
189  /// \name Countable Loop Idiom Handling
190  /// @{
191
192  bool runOnCountableLoop();
193  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
194                      SmallVectorImpl<BasicBlock *> &ExitBlocks);
195
196  void collectStores(BasicBlock *BB);
197  LegalStoreKind isLegalStore(StoreInst *SI);
198  enum class ForMemset { No, Yes };
199  bool processLoopStores(SmallVectorImpl<StoreInst *> &SL, const SCEV *BECount,
200                         ForMemset For);
201
202  template <typename MemInst>
203  bool processLoopMemIntrinsic(
204      BasicBlock *BB,
205      bool (LoopIdiomRecognize::*Processor)(MemInst *, const SCEV *),
206      const SCEV *BECount);
207  bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount);
208  bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);
209
210  bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV,
211                               MaybeAlign StoreAlignment, Value *StoredVal,
212                               Instruction *TheStore,
213                               SmallPtrSetImpl<Instruction *> &Stores,
214                               const SCEVAddRecExpr *Ev, const SCEV *BECount,
215                               bool IsNegStride, bool IsLoopMemset = false);
216  bool processLoopStoreOfLoopLoad(StoreInst *SI, const SCEV *BECount);
217  bool processLoopStoreOfLoopLoad(Value *DestPtr, Value *SourcePtr,
218                                  const SCEV *StoreSize, MaybeAlign StoreAlign,
219                                  MaybeAlign LoadAlign, Instruction *TheStore,
220                                  Instruction *TheLoad,
221                                  const SCEVAddRecExpr *StoreEv,
222                                  const SCEVAddRecExpr *LoadEv,
223                                  const SCEV *BECount);
224  bool avoidLIRForMultiBlockLoop(bool IsMemset = false,
225                                 bool IsLoopMemset = false);
226
227  /// @}
228  /// \name Noncountable Loop Idiom Handling
229  /// @{
230
231  bool runOnNoncountableLoop();
232
233  bool recognizePopcount();
234  void transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst,
235                               PHINode *CntPhi, Value *Var);
236  bool recognizeAndInsertFFS();  /// Find First Set: ctlz or cttz
237  void transformLoopToCountable(Intrinsic::ID IntrinID, BasicBlock *PreCondBB,
238                                Instruction *CntInst, PHINode *CntPhi,
239                                Value *Var, Instruction *DefX,
240                                const DebugLoc &DL, bool ZeroCheck,
241                                bool IsCntPhiUsedOutsideLoop);
242
243  bool recognizeShiftUntilBitTest();
244  bool recognizeShiftUntilZero();
245
246  /// @}
247};
248} // end anonymous namespace
249
250PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
251                                              LoopStandardAnalysisResults &AR,
252                                              LPMUpdater &) {
253  if (DisableLIRP::All)
254    return PreservedAnalyses::all();
255
256  const auto *DL = &L.getHeader()->getModule()->getDataLayout();
257
258  // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis
259  // pass.  Function analyses need to be preserved across loop transformations
260  // but ORE cannot be preserved (see comment before the pass definition).
261  OptimizationRemarkEmitter ORE(L.getHeader()->getParent());
262
263  LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI,
264                         AR.MSSA, DL, ORE);
265  if (!LIR.runOnLoop(&L))
266    return PreservedAnalyses::all();
267
268  auto PA = getLoopPassPreservedAnalyses();
269  if (AR.MSSA)
270    PA.preserve<MemorySSAAnalysis>();
271  return PA;
272}
273
274static void deleteDeadInstruction(Instruction *I) {
275  I->replaceAllUsesWith(PoisonValue::get(I->getType()));
276  I->eraseFromParent();
277}
278
279//===----------------------------------------------------------------------===//
280//
281//          Implementation of LoopIdiomRecognize
282//
283//===----------------------------------------------------------------------===//
284
285bool LoopIdiomRecognize::runOnLoop(Loop *L) {
286  CurLoop = L;
287  // If the loop could not be converted to canonical form, it must have an
288  // indirectbr in it, just give up.
289  if (!L->getLoopPreheader())
290    return false;
291
292  // Disable loop idiom recognition if the function's name is a common idiom.
293  StringRef Name = L->getHeader()->getParent()->getName();
294  if (Name == "memset" || Name == "memcpy")
295    return false;
296
297  // Determine if code size heuristics need to be applied.
298  ApplyCodeSizeHeuristics =
299      L->getHeader()->getParent()->hasOptSize() && UseLIRCodeSizeHeurs;
300
301  HasMemset = TLI->has(LibFunc_memset);
302  HasMemsetPattern = TLI->has(LibFunc_memset_pattern16);
303  HasMemcpy = TLI->has(LibFunc_memcpy);
304
305  if (HasMemset || HasMemsetPattern || HasMemcpy)
306    if (SE->hasLoopInvariantBackedgeTakenCount(L))
307      return runOnCountableLoop();
308
309  return runOnNoncountableLoop();
310}
311
312bool LoopIdiomRecognize::runOnCountableLoop() {
313  const SCEV *BECount = SE->getBackedgeTakenCount(CurLoop);
314  assert(!isa<SCEVCouldNotCompute>(BECount) &&
315         "runOnCountableLoop() called on a loop without a predictable"
316         "backedge-taken count");
317
318  // If this loop executes exactly one time, then it should be peeled, not
319  // optimized by this pass.
320  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
321    if (BECst->getAPInt() == 0)
322      return false;
323
324  SmallVector<BasicBlock *, 8> ExitBlocks;
325  CurLoop->getUniqueExitBlocks(ExitBlocks);
326
327  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
328                    << CurLoop->getHeader()->getParent()->getName()
329                    << "] Countable Loop %" << CurLoop->getHeader()->getName()
330                    << "\n");
331
332  // The following transforms hoist stores/memsets into the loop pre-header.
333  // Give up if the loop has instructions that may throw.
334  SimpleLoopSafetyInfo SafetyInfo;
335  SafetyInfo.computeLoopSafetyInfo(CurLoop);
336  if (SafetyInfo.anyBlockMayThrow())
337    return false;
338
339  bool MadeChange = false;
340
341  // Scan all the blocks in the loop that are not in subloops.
342  for (auto *BB : CurLoop->getBlocks()) {
343    // Ignore blocks in subloops.
344    if (LI->getLoopFor(BB) != CurLoop)
345      continue;
346
347    MadeChange |= runOnLoopBlock(BB, BECount, ExitBlocks);
348  }
349  return MadeChange;
350}
351
352static APInt getStoreStride(const SCEVAddRecExpr *StoreEv) {
353  const SCEVConstant *ConstStride = cast<SCEVConstant>(StoreEv->getOperand(1));
354  return ConstStride->getAPInt();
355}
356
357/// getMemSetPatternValue - If a strided store of the specified value is safe to
358/// turn into a memset_pattern16, return a ConstantArray of 16 bytes that should
359/// be passed in.  Otherwise, return null.
360///
361/// Note that we don't ever attempt to use memset_pattern8 or 4, because these
362/// just replicate their input array and then pass on to memset_pattern16.
363static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) {
364  // FIXME: This could check for UndefValue because it can be merged into any
365  // other valid pattern.
366
367  // If the value isn't a constant, we can't promote it to being in a constant
368  // array.  We could theoretically do a store to an alloca or something, but
369  // that doesn't seem worthwhile.
370  Constant *C = dyn_cast<Constant>(V);
371  if (!C || isa<ConstantExpr>(C))
372    return nullptr;
373
374  // Only handle simple values that are a power of two bytes in size.
375  uint64_t Size = DL->getTypeSizeInBits(V->getType());
376  if (Size == 0 || (Size & 7) || (Size & (Size - 1)))
377    return nullptr;
378
379  // Don't care enough about darwin/ppc to implement this.
380  if (DL->isBigEndian())
381    return nullptr;
382
383  // Convert to size in bytes.
384  Size /= 8;
385
386  // TODO: If CI is larger than 16-bytes, we can try slicing it in half to see
387  // if the top and bottom are the same (e.g. for vectors and large integers).
388  if (Size > 16)
389    return nullptr;
390
391  // If the constant is exactly 16 bytes, just use it.
392  if (Size == 16)
393    return C;
394
395  // Otherwise, we'll use an array of the constants.
396  unsigned ArraySize = 16 / Size;
397  ArrayType *AT = ArrayType::get(V->getType(), ArraySize);
398  return ConstantArray::get(AT, std::vector<Constant *>(ArraySize, C));
399}
400
401LoopIdiomRecognize::LegalStoreKind
402LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
403  // Don't touch volatile stores.
404  if (SI->isVolatile())
405    return LegalStoreKind::None;
406  // We only want simple or unordered-atomic stores.
407  if (!SI->isUnordered())
408    return LegalStoreKind::None;
409
410  // Avoid merging nontemporal stores.
411  if (SI->getMetadata(LLVMContext::MD_nontemporal))
412    return LegalStoreKind::None;
413
414  Value *StoredVal = SI->getValueOperand();
415  Value *StorePtr = SI->getPointerOperand();
416
417  // Don't convert stores of non-integral pointer types to memsets (which stores
418  // integers).
419  if (DL->isNonIntegralPointerType(StoredVal->getType()->getScalarType()))
420    return LegalStoreKind::None;
421
422  // Reject stores that are so large that they overflow an unsigned.
423  // When storing out scalable vectors we bail out for now, since the code
424  // below currently only works for constant strides.
425  TypeSize SizeInBits = DL->getTypeSizeInBits(StoredVal->getType());
426  if (SizeInBits.isScalable() || (SizeInBits.getFixedValue() & 7) ||
427      (SizeInBits.getFixedValue() >> 32) != 0)
428    return LegalStoreKind::None;
429
430  // See if the pointer expression is an AddRec like {base,+,1} on the current
431  // loop, which indicates a strided store.  If we have something else, it's a
432  // random store we can't handle.
433  const SCEVAddRecExpr *StoreEv =
434      dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
435  if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
436    return LegalStoreKind::None;
437
438  // Check to see if we have a constant stride.
439  if (!isa<SCEVConstant>(StoreEv->getOperand(1)))
440    return LegalStoreKind::None;
441
442  // See if the store can be turned into a memset.
443
444  // If the stored value is a byte-wise value (like i32 -1), then it may be
445  // turned into a memset of i8 -1, assuming that all the consecutive bytes
446  // are stored.  A store of i32 0x01020304 can never be turned into a memset,
447  // but it can be turned into memset_pattern if the target supports it.
448  Value *SplatValue = isBytewiseValue(StoredVal, *DL);
449
450  // Note: memset and memset_pattern on unordered-atomic is yet not supported
451  bool UnorderedAtomic = SI->isUnordered() && !SI->isSimple();
452
453  // If we're allowed to form a memset, and the stored value would be
454  // acceptable for memset, use it.
455  if (!UnorderedAtomic && HasMemset && SplatValue && !DisableLIRP::Memset &&
456      // Verify that the stored value is loop invariant.  If not, we can't
457      // promote the memset.
458      CurLoop->isLoopInvariant(SplatValue)) {
459    // It looks like we can use SplatValue.
460    return LegalStoreKind::Memset;
461  }
462  if (!UnorderedAtomic && HasMemsetPattern && !DisableLIRP::Memset &&
463      // Don't create memset_pattern16s with address spaces.
464      StorePtr->getType()->getPointerAddressSpace() == 0 &&
465      getMemSetPatternValue(StoredVal, DL)) {
466    // It looks like we can use PatternValue!
467    return LegalStoreKind::MemsetPattern;
468  }
469
470  // Otherwise, see if the store can be turned into a memcpy.
471  if (HasMemcpy && !DisableLIRP::Memcpy) {
472    // Check to see if the stride matches the size of the store.  If so, then we
473    // know that every byte is touched in the loop.
474    APInt Stride = getStoreStride(StoreEv);
475    unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType());
476    if (StoreSize != Stride && StoreSize != -Stride)
477      return LegalStoreKind::None;
478
479    // The store must be feeding a non-volatile load.
480    LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand());
481
482    // Only allow non-volatile loads
483    if (!LI || LI->isVolatile())
484      return LegalStoreKind::None;
485    // Only allow simple or unordered-atomic loads
486    if (!LI->isUnordered())
487      return LegalStoreKind::None;
488
489    // See if the pointer expression is an AddRec like {base,+,1} on the current
490    // loop, which indicates a strided load.  If we have something else, it's a
491    // random load we can't handle.
492    const SCEVAddRecExpr *LoadEv =
493        dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand()));
494    if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
495      return LegalStoreKind::None;
496
497    // The store and load must share the same stride.
498    if (StoreEv->getOperand(1) != LoadEv->getOperand(1))
499      return LegalStoreKind::None;
500
501    // Success.  This store can be converted into a memcpy.
502    UnorderedAtomic = UnorderedAtomic || LI->isAtomic();
503    return UnorderedAtomic ? LegalStoreKind::UnorderedAtomicMemcpy
504                           : LegalStoreKind::Memcpy;
505  }
506  // This store can't be transformed into a memset/memcpy.
507  return LegalStoreKind::None;
508}
509
510void LoopIdiomRecognize::collectStores(BasicBlock *BB) {
511  StoreRefsForMemset.clear();
512  StoreRefsForMemsetPattern.clear();
513  StoreRefsForMemcpy.clear();
514  for (Instruction &I : *BB) {
515    StoreInst *SI = dyn_cast<StoreInst>(&I);
516    if (!SI)
517      continue;
518
519    // Make sure this is a strided store with a constant stride.
520    switch (isLegalStore(SI)) {
521    case LegalStoreKind::None:
522      // Nothing to do
523      break;
524    case LegalStoreKind::Memset: {
525      // Find the base pointer.
526      Value *Ptr = getUnderlyingObject(SI->getPointerOperand());
527      StoreRefsForMemset[Ptr].push_back(SI);
528    } break;
529    case LegalStoreKind::MemsetPattern: {
530      // Find the base pointer.
531      Value *Ptr = getUnderlyingObject(SI->getPointerOperand());
532      StoreRefsForMemsetPattern[Ptr].push_back(SI);
533    } break;
534    case LegalStoreKind::Memcpy:
535    case LegalStoreKind::UnorderedAtomicMemcpy:
536      StoreRefsForMemcpy.push_back(SI);
537      break;
538    default:
539      assert(false && "unhandled return value");
540      break;
541    }
542  }
543}
544
545/// runOnLoopBlock - Process the specified block, which lives in a counted loop
546/// with the specified backedge count.  This block is known to be in the current
547/// loop and not in any subloops.
548bool LoopIdiomRecognize::runOnLoopBlock(
549    BasicBlock *BB, const SCEV *BECount,
550    SmallVectorImpl<BasicBlock *> &ExitBlocks) {
551  // We can only promote stores in this block if they are unconditionally
552  // executed in the loop.  For a block to be unconditionally executed, it has
553  // to dominate all the exit blocks of the loop.  Verify this now.
554  for (BasicBlock *ExitBlock : ExitBlocks)
555    if (!DT->dominates(BB, ExitBlock))
556      return false;
557
558  bool MadeChange = false;
559  // Look for store instructions, which may be optimized to memset/memcpy.
560  collectStores(BB);
561
562  // Look for a single store or sets of stores with a common base, which can be
563  // optimized into a memset (memset_pattern).  The latter most commonly happens
564  // with structs and handunrolled loops.
565  for (auto &SL : StoreRefsForMemset)
566    MadeChange |= processLoopStores(SL.second, BECount, ForMemset::Yes);
567
568  for (auto &SL : StoreRefsForMemsetPattern)
569    MadeChange |= processLoopStores(SL.second, BECount, ForMemset::No);
570
571  // Optimize the store into a memcpy, if it feeds an similarly strided load.
572  for (auto &SI : StoreRefsForMemcpy)
573    MadeChange |= processLoopStoreOfLoopLoad(SI, BECount);
574
575  MadeChange |= processLoopMemIntrinsic<MemCpyInst>(
576      BB, &LoopIdiomRecognize::processLoopMemCpy, BECount);
577  MadeChange |= processLoopMemIntrinsic<MemSetInst>(
578      BB, &LoopIdiomRecognize::processLoopMemSet, BECount);
579
580  return MadeChange;
581}
582
583/// See if this store(s) can be promoted to a memset.
584bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL,
585                                           const SCEV *BECount, ForMemset For) {
586  // Try to find consecutive stores that can be transformed into memsets.
587  SetVector<StoreInst *> Heads, Tails;
588  SmallDenseMap<StoreInst *, StoreInst *> ConsecutiveChain;
589
590  // Do a quadratic search on all of the given stores and find
591  // all of the pairs of stores that follow each other.
592  SmallVector<unsigned, 16> IndexQueue;
593  for (unsigned i = 0, e = SL.size(); i < e; ++i) {
594    assert(SL[i]->isSimple() && "Expected only non-volatile stores.");
595
596    Value *FirstStoredVal = SL[i]->getValueOperand();
597    Value *FirstStorePtr = SL[i]->getPointerOperand();
598    const SCEVAddRecExpr *FirstStoreEv =
599        cast<SCEVAddRecExpr>(SE->getSCEV(FirstStorePtr));
600    APInt FirstStride = getStoreStride(FirstStoreEv);
601    unsigned FirstStoreSize = DL->getTypeStoreSize(SL[i]->getValueOperand()->getType());
602
603    // See if we can optimize just this store in isolation.
604    if (FirstStride == FirstStoreSize || -FirstStride == FirstStoreSize) {
605      Heads.insert(SL[i]);
606      continue;
607    }
608
609    Value *FirstSplatValue = nullptr;
610    Constant *FirstPatternValue = nullptr;
611
612    if (For == ForMemset::Yes)
613      FirstSplatValue = isBytewiseValue(FirstStoredVal, *DL);
614    else
615      FirstPatternValue = getMemSetPatternValue(FirstStoredVal, DL);
616
617    assert((FirstSplatValue || FirstPatternValue) &&
618           "Expected either splat value or pattern value.");
619
620    IndexQueue.clear();
621    // If a store has multiple consecutive store candidates, search Stores
622    // array according to the sequence: from i+1 to e, then from i-1 to 0.
623    // This is because usually pairing with immediate succeeding or preceding
624    // candidate create the best chance to find memset opportunity.
625    unsigned j = 0;
626    for (j = i + 1; j < e; ++j)
627      IndexQueue.push_back(j);
628    for (j = i; j > 0; --j)
629      IndexQueue.push_back(j - 1);
630
631    for (auto &k : IndexQueue) {
632      assert(SL[k]->isSimple() && "Expected only non-volatile stores.");
633      Value *SecondStorePtr = SL[k]->getPointerOperand();
634      const SCEVAddRecExpr *SecondStoreEv =
635          cast<SCEVAddRecExpr>(SE->getSCEV(SecondStorePtr));
636      APInt SecondStride = getStoreStride(SecondStoreEv);
637
638      if (FirstStride != SecondStride)
639        continue;
640
641      Value *SecondStoredVal = SL[k]->getValueOperand();
642      Value *SecondSplatValue = nullptr;
643      Constant *SecondPatternValue = nullptr;
644
645      if (For == ForMemset::Yes)
646        SecondSplatValue = isBytewiseValue(SecondStoredVal, *DL);
647      else
648        SecondPatternValue = getMemSetPatternValue(SecondStoredVal, DL);
649
650      assert((SecondSplatValue || SecondPatternValue) &&
651             "Expected either splat value or pattern value.");
652
653      if (isConsecutiveAccess(SL[i], SL[k], *DL, *SE, false)) {
654        if (For == ForMemset::Yes) {
655          if (isa<UndefValue>(FirstSplatValue))
656            FirstSplatValue = SecondSplatValue;
657          if (FirstSplatValue != SecondSplatValue)
658            continue;
659        } else {
660          if (isa<UndefValue>(FirstPatternValue))
661            FirstPatternValue = SecondPatternValue;
662          if (FirstPatternValue != SecondPatternValue)
663            continue;
664        }
665        Tails.insert(SL[k]);
666        Heads.insert(SL[i]);
667        ConsecutiveChain[SL[i]] = SL[k];
668        break;
669      }
670    }
671  }
672
673  // We may run into multiple chains that merge into a single chain. We mark the
674  // stores that we transformed so that we don't visit the same store twice.
675  SmallPtrSet<Value *, 16> TransformedStores;
676  bool Changed = false;
677
678  // For stores that start but don't end a link in the chain:
679  for (StoreInst *I : Heads) {
680    if (Tails.count(I))
681      continue;
682
683    // We found a store instr that starts a chain. Now follow the chain and try
684    // to transform it.
685    SmallPtrSet<Instruction *, 8> AdjacentStores;
686    StoreInst *HeadStore = I;
687    unsigned StoreSize = 0;
688
689    // Collect the chain into a list.
690    while (Tails.count(I) || Heads.count(I)) {
691      if (TransformedStores.count(I))
692        break;
693      AdjacentStores.insert(I);
694
695      StoreSize += DL->getTypeStoreSize(I->getValueOperand()->getType());
696      // Move to the next value in the chain.
697      I = ConsecutiveChain[I];
698    }
699
700    Value *StoredVal = HeadStore->getValueOperand();
701    Value *StorePtr = HeadStore->getPointerOperand();
702    const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
703    APInt Stride = getStoreStride(StoreEv);
704
705    // Check to see if the stride matches the size of the stores.  If so, then
706    // we know that every byte is touched in the loop.
707    if (StoreSize != Stride && StoreSize != -Stride)
708      continue;
709
710    bool IsNegStride = StoreSize == -Stride;
711
712    Type *IntIdxTy = DL->getIndexType(StorePtr->getType());
713    const SCEV *StoreSizeSCEV = SE->getConstant(IntIdxTy, StoreSize);
714    if (processLoopStridedStore(StorePtr, StoreSizeSCEV,
715                                MaybeAlign(HeadStore->getAlign()), StoredVal,
716                                HeadStore, AdjacentStores, StoreEv, BECount,
717                                IsNegStride)) {
718      TransformedStores.insert(AdjacentStores.begin(), AdjacentStores.end());
719      Changed = true;
720    }
721  }
722
723  return Changed;
724}
725
726/// processLoopMemIntrinsic - Template function for calling different processor
727/// functions based on mem intrinsic type.
728template <typename MemInst>
729bool LoopIdiomRecognize::processLoopMemIntrinsic(
730    BasicBlock *BB,
731    bool (LoopIdiomRecognize::*Processor)(MemInst *, const SCEV *),
732    const SCEV *BECount) {
733  bool MadeChange = false;
734  for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) {
735    Instruction *Inst = &*I++;
736    // Look for memory instructions, which may be optimized to a larger one.
737    if (MemInst *MI = dyn_cast<MemInst>(Inst)) {
738      WeakTrackingVH InstPtr(&*I);
739      if (!(this->*Processor)(MI, BECount))
740        continue;
741      MadeChange = true;
742
743      // If processing the instruction invalidated our iterator, start over from
744      // the top of the block.
745      if (!InstPtr)
746        I = BB->begin();
747    }
748  }
749  return MadeChange;
750}
751
752/// processLoopMemCpy - See if this memcpy can be promoted to a large memcpy
753bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
754                                           const SCEV *BECount) {
755  // We can only handle non-volatile memcpys with a constant size.
756  if (MCI->isVolatile() || !isa<ConstantInt>(MCI->getLength()))
757    return false;
758
759  // If we're not allowed to hack on memcpy, we fail.
760  if ((!HasMemcpy && !isa<MemCpyInlineInst>(MCI)) || DisableLIRP::Memcpy)
761    return false;
762
763  Value *Dest = MCI->getDest();
764  Value *Source = MCI->getSource();
765  if (!Dest || !Source)
766    return false;
767
768  // See if the load and store pointer expressions are AddRec like {base,+,1} on
769  // the current loop, which indicates a strided load and store.  If we have
770  // something else, it's a random load or store we can't handle.
771  const SCEVAddRecExpr *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Dest));
772  if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
773    return false;
774  const SCEVAddRecExpr *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Source));
775  if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
776    return false;
777
778  // Reject memcpys that are so large that they overflow an unsigned.
779  uint64_t SizeInBytes = cast<ConstantInt>(MCI->getLength())->getZExtValue();
780  if ((SizeInBytes >> 32) != 0)
781    return false;
782
783  // Check if the stride matches the size of the memcpy. If so, then we know
784  // that every byte is touched in the loop.
785  const SCEVConstant *ConstStoreStride =
786      dyn_cast<SCEVConstant>(StoreEv->getOperand(1));
787  const SCEVConstant *ConstLoadStride =
788      dyn_cast<SCEVConstant>(LoadEv->getOperand(1));
789  if (!ConstStoreStride || !ConstLoadStride)
790    return false;
791
792  APInt StoreStrideValue = ConstStoreStride->getAPInt();
793  APInt LoadStrideValue = ConstLoadStride->getAPInt();
794  // Huge stride value - give up
795  if (StoreStrideValue.getBitWidth() > 64 || LoadStrideValue.getBitWidth() > 64)
796    return false;
797
798  if (SizeInBytes != StoreStrideValue && SizeInBytes != -StoreStrideValue) {
799    ORE.emit([&]() {
800      return OptimizationRemarkMissed(DEBUG_TYPE, "SizeStrideUnequal", MCI)
801             << ore::NV("Inst", "memcpy") << " in "
802             << ore::NV("Function", MCI->getFunction())
803             << " function will not be hoisted: "
804             << ore::NV("Reason", "memcpy size is not equal to stride");
805    });
806    return false;
807  }
808
809  int64_t StoreStrideInt = StoreStrideValue.getSExtValue();
810  int64_t LoadStrideInt = LoadStrideValue.getSExtValue();
811  // Check if the load stride matches the store stride.
812  if (StoreStrideInt != LoadStrideInt)
813    return false;
814
815  return processLoopStoreOfLoopLoad(
816      Dest, Source, SE->getConstant(Dest->getType(), SizeInBytes),
817      MCI->getDestAlign(), MCI->getSourceAlign(), MCI, MCI, StoreEv, LoadEv,
818      BECount);
819}
820
821/// processLoopMemSet - See if this memset can be promoted to a large memset.
822bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
823                                           const SCEV *BECount) {
824  // We can only handle non-volatile memsets.
825  if (MSI->isVolatile())
826    return false;
827
828  // If we're not allowed to hack on memset, we fail.
829  if (!HasMemset || DisableLIRP::Memset)
830    return false;
831
832  Value *Pointer = MSI->getDest();
833
834  // See if the pointer expression is an AddRec like {base,+,1} on the current
835  // loop, which indicates a strided store.  If we have something else, it's a
836  // random store we can't handle.
837  const SCEVAddRecExpr *Ev = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Pointer));
838  if (!Ev || Ev->getLoop() != CurLoop)
839    return false;
840  if (!Ev->isAffine()) {
841    LLVM_DEBUG(dbgs() << "  Pointer is not affine, abort\n");
842    return false;
843  }
844
845  const SCEV *PointerStrideSCEV = Ev->getOperand(1);
846  const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
847  if (!PointerStrideSCEV || !MemsetSizeSCEV)
848    return false;
849
850  bool IsNegStride = false;
851  const bool IsConstantSize = isa<ConstantInt>(MSI->getLength());
852
853  if (IsConstantSize) {
854    // Memset size is constant.
855    // Check if the pointer stride matches the memset size. If so, then
856    // we know that every byte is touched in the loop.
857    LLVM_DEBUG(dbgs() << "  memset size is constant\n");
858    uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
859    const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
860    if (!ConstStride)
861      return false;
862
863    APInt Stride = ConstStride->getAPInt();
864    if (SizeInBytes != Stride && SizeInBytes != -Stride)
865      return false;
866
867    IsNegStride = SizeInBytes == -Stride;
868  } else {
869    // Memset size is non-constant.
870    // Check if the pointer stride matches the memset size.
871    // To be conservative, the pass would not promote pointers that aren't in
872    // address space zero. Also, the pass only handles memset length and stride
873    // that are invariant for the top level loop.
874    LLVM_DEBUG(dbgs() << "  memset size is non-constant\n");
875    if (Pointer->getType()->getPointerAddressSpace() != 0) {
876      LLVM_DEBUG(dbgs() << "  pointer is not in address space zero, "
877                        << "abort\n");
878      return false;
879    }
880    if (!SE->isLoopInvariant(MemsetSizeSCEV, CurLoop)) {
881      LLVM_DEBUG(dbgs() << "  memset size is not a loop-invariant, "
882                        << "abort\n");
883      return false;
884    }
885
886    // Compare positive direction PointerStrideSCEV with MemsetSizeSCEV
887    IsNegStride = PointerStrideSCEV->isNonConstantNegative();
888    const SCEV *PositiveStrideSCEV =
889        IsNegStride ? SE->getNegativeSCEV(PointerStrideSCEV)
890                    : PointerStrideSCEV;
891    LLVM_DEBUG(dbgs() << "  MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n"
892                      << "  PositiveStrideSCEV: " << *PositiveStrideSCEV
893                      << "\n");
894
895    if (PositiveStrideSCEV != MemsetSizeSCEV) {
896      // If an expression is covered by the loop guard, compare again and
897      // proceed with optimization if equal.
898      const SCEV *FoldedPositiveStride =
899          SE->applyLoopGuards(PositiveStrideSCEV, CurLoop);
900      const SCEV *FoldedMemsetSize =
901          SE->applyLoopGuards(MemsetSizeSCEV, CurLoop);
902
903      LLVM_DEBUG(dbgs() << "  Try to fold SCEV based on loop guard\n"
904                        << "    FoldedMemsetSize: " << *FoldedMemsetSize << "\n"
905                        << "    FoldedPositiveStride: " << *FoldedPositiveStride
906                        << "\n");
907
908      if (FoldedPositiveStride != FoldedMemsetSize) {
909        LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
910        return false;
911      }
912    }
913  }
914
915  // Verify that the memset value is loop invariant.  If not, we can't promote
916  // the memset.
917  Value *SplatValue = MSI->getValue();
918  if (!SplatValue || !CurLoop->isLoopInvariant(SplatValue))
919    return false;
920
921  SmallPtrSet<Instruction *, 1> MSIs;
922  MSIs.insert(MSI);
923  return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
924                                 MSI->getDestAlign(), SplatValue, MSI, MSIs, Ev,
925                                 BECount, IsNegStride, /*IsLoopMemset=*/true);
926}
927
928/// mayLoopAccessLocation - Return true if the specified loop might access the
929/// specified pointer location, which is a loop-strided access.  The 'Access'
930/// argument specifies what the verboten forms of access are (read or write).
931static bool
932mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
933                      const SCEV *BECount, const SCEV *StoreSizeSCEV,
934                      AliasAnalysis &AA,
935                      SmallPtrSetImpl<Instruction *> &IgnoredInsts) {
936  // Get the location that may be stored across the loop.  Since the access is
937  // strided positively through memory, we say that the modified location starts
938  // at the pointer and has infinite size.
939  LocationSize AccessSize = LocationSize::afterPointer();
940
941  // If the loop iterates a fixed number of times, we can refine the access size
942  // to be exactly the size of the memset, which is (BECount+1)*StoreSize
943  const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
944  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
945  if (BECst && ConstSize) {
946    std::optional<uint64_t> BEInt = BECst->getAPInt().tryZExtValue();
947    std::optional<uint64_t> SizeInt = ConstSize->getAPInt().tryZExtValue();
948    // FIXME: Should this check for overflow?
949    if (BEInt && SizeInt)
950      AccessSize = LocationSize::precise((*BEInt + 1) * *SizeInt);
951  }
952
953  // TODO: For this to be really effective, we have to dive into the pointer
954  // operand in the store.  Store to &A[i] of 100 will always return may alias
955  // with store of &A[100], we need to StoreLoc to be "A" with size of 100,
956  // which will then no-alias a store to &A[100].
957  MemoryLocation StoreLoc(Ptr, AccessSize);
958
959  for (BasicBlock *B : L->blocks())
960    for (Instruction &I : *B)
961      if (!IgnoredInsts.contains(&I) &&
962          isModOrRefSet(AA.getModRefInfo(&I, StoreLoc) & Access))
963        return true;
964  return false;
965}
966
967// If we have a negative stride, Start refers to the end of the memory location
968// we're trying to memset.  Therefore, we need to recompute the base pointer,
969// which is just Start - BECount*Size.
970static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
971                                        Type *IntPtr, const SCEV *StoreSizeSCEV,
972                                        ScalarEvolution *SE) {
973  const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr);
974  if (!StoreSizeSCEV->isOne()) {
975    // index = back edge count * store size
976    Index = SE->getMulExpr(Index,
977                           SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
978                           SCEV::FlagNUW);
979  }
980  // base pointer = start - index * store size
981  return SE->getMinusSCEV(Start, Index);
982}
983
984/// Compute the number of bytes as a SCEV from the backedge taken count.
985///
986/// This also maps the SCEV into the provided type and tries to handle the
987/// computation in a way that will fold cleanly.
988static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
989                               const SCEV *StoreSizeSCEV, Loop *CurLoop,
990                               const DataLayout *DL, ScalarEvolution *SE) {
991  const SCEV *TripCountSCEV =
992      SE->getTripCountFromExitCount(BECount, IntPtr, CurLoop);
993  return SE->getMulExpr(TripCountSCEV,
994                        SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
995                        SCEV::FlagNUW);
996}
997
998/// processLoopStridedStore - We see a strided store of some value.  If we can
999/// transform this into a memset or memset_pattern in the loop preheader, do so.
1000bool LoopIdiomRecognize::processLoopStridedStore(
1001    Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment,
1002    Value *StoredVal, Instruction *TheStore,
1003    SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
1004    const SCEV *BECount, bool IsNegStride, bool IsLoopMemset) {
1005  Module *M = TheStore->getModule();
1006  Value *SplatValue = isBytewiseValue(StoredVal, *DL);
1007  Constant *PatternValue = nullptr;
1008
1009  if (!SplatValue)
1010    PatternValue = getMemSetPatternValue(StoredVal, DL);
1011
1012  assert((SplatValue || PatternValue) &&
1013         "Expected either splat value or pattern value.");
1014
1015  // The trip count of the loop and the base pointer of the addrec SCEV is
1016  // guaranteed to be loop invariant, which means that it should dominate the
1017  // header.  This allows us to insert code for it in the preheader.
1018  unsigned DestAS = DestPtr->getType()->getPointerAddressSpace();
1019  BasicBlock *Preheader = CurLoop->getLoopPreheader();
1020  IRBuilder<> Builder(Preheader->getTerminator());
1021  SCEVExpander Expander(*SE, *DL, "loop-idiom");
1022  SCEVExpanderCleaner ExpCleaner(Expander);
1023
1024  Type *DestInt8PtrTy = Builder.getPtrTy(DestAS);
1025  Type *IntIdxTy = DL->getIndexType(DestPtr->getType());
1026
1027  bool Changed = false;
1028  const SCEV *Start = Ev->getStart();
1029  // Handle negative strided loops.
1030  if (IsNegStride)
1031    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSizeSCEV, SE);
1032
1033  // TODO: ideally we should still be able to generate memset if SCEV expander
1034  // is taught to generate the dependencies at the latest point.
1035  if (!Expander.isSafeToExpand(Start))
1036    return Changed;
1037
1038  // Okay, we have a strided store "p[i]" of a splattable value.  We can turn
1039  // this into a memset in the loop preheader now if we want.  However, this
1040  // would be unsafe to do if there is anything else in the loop that may read
1041  // or write to the aliased location.  Check for any overlap by generating the
1042  // base pointer and checking the region.
1043  Value *BasePtr =
1044      Expander.expandCodeFor(Start, DestInt8PtrTy, Preheader->getTerminator());
1045
1046  // From here on out, conservatively report to the pass manager that we've
1047  // changed the IR, even if we later clean up these added instructions. There
1048  // may be structural differences e.g. in the order of use lists not accounted
1049  // for in just a textual dump of the IR. This is written as a variable, even
1050  // though statically all the places this dominates could be replaced with
1051  // 'true', with the hope that anyone trying to be clever / "more precise" with
1052  // the return value will read this comment, and leave them alone.
1053  Changed = true;
1054
1055  if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount,
1056                            StoreSizeSCEV, *AA, Stores))
1057    return Changed;
1058
1059  if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset))
1060    return Changed;
1061
1062  // Okay, everything looks good, insert the memset.
1063
1064  const SCEV *NumBytesS =
1065      getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
1066
1067  // TODO: ideally we should still be able to generate memset if SCEV expander
1068  // is taught to generate the dependencies at the latest point.
1069  if (!Expander.isSafeToExpand(NumBytesS))
1070    return Changed;
1071
1072  Value *NumBytes =
1073      Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());
1074
1075  if (!SplatValue && !isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16))
1076    return Changed;
1077
1078  AAMDNodes AATags = TheStore->getAAMetadata();
1079  for (Instruction *Store : Stores)
1080    AATags = AATags.merge(Store->getAAMetadata());
1081  if (auto CI = dyn_cast<ConstantInt>(NumBytes))
1082    AATags = AATags.extendTo(CI->getZExtValue());
1083  else
1084    AATags = AATags.extendTo(-1);
1085
1086  CallInst *NewCall;
1087  if (SplatValue) {
1088    NewCall = Builder.CreateMemSet(
1089        BasePtr, SplatValue, NumBytes, MaybeAlign(StoreAlignment),
1090        /*isVolatile=*/false, AATags.TBAA, AATags.Scope, AATags.NoAlias);
1091  } else {
1092    assert (isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16));
1093    // Everything is emitted in default address space
1094    Type *Int8PtrTy = DestInt8PtrTy;
1095
1096    StringRef FuncName = "memset_pattern16";
1097    FunctionCallee MSP = getOrInsertLibFunc(M, *TLI, LibFunc_memset_pattern16,
1098                            Builder.getVoidTy(), Int8PtrTy, Int8PtrTy, IntIdxTy);
1099    inferNonMandatoryLibFuncAttrs(M, FuncName, *TLI);
1100
1101    // Otherwise we should form a memset_pattern16.  PatternValue is known to be
1102    // an constant array of 16-bytes.  Plop the value into a mergable global.
1103    GlobalVariable *GV = new GlobalVariable(*M, PatternValue->getType(), true,
1104                                            GlobalValue::PrivateLinkage,
1105                                            PatternValue, ".memset_pattern");
1106    GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Ok to merge these.
1107    GV->setAlignment(Align(16));
1108    Value *PatternPtr = GV;
1109    NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes});
1110
1111    // Set the TBAA info if present.
1112    if (AATags.TBAA)
1113      NewCall->setMetadata(LLVMContext::MD_tbaa, AATags.TBAA);
1114
1115    if (AATags.Scope)
1116      NewCall->setMetadata(LLVMContext::MD_alias_scope, AATags.Scope);
1117
1118    if (AATags.NoAlias)
1119      NewCall->setMetadata(LLVMContext::MD_noalias, AATags.NoAlias);
1120  }
1121
1122  NewCall->setDebugLoc(TheStore->getDebugLoc());
1123
1124  if (MSSAU) {
1125    MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB(
1126        NewCall, nullptr, NewCall->getParent(), MemorySSA::BeforeTerminator);
1127    MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true);
1128  }
1129
1130  LLVM_DEBUG(dbgs() << "  Formed memset: " << *NewCall << "\n"
1131                    << "    from store to: " << *Ev << " at: " << *TheStore
1132                    << "\n");
1133
1134  ORE.emit([&]() {
1135    OptimizationRemark R(DEBUG_TYPE, "ProcessLoopStridedStore",
1136                         NewCall->getDebugLoc(), Preheader);
1137    R << "Transformed loop-strided store in "
1138      << ore::NV("Function", TheStore->getFunction())
1139      << " function into a call to "
1140      << ore::NV("NewFunction", NewCall->getCalledFunction())
1141      << "() intrinsic";
1142    if (!Stores.empty())
1143      R << ore::setExtraArgs();
1144    for (auto *I : Stores) {
1145      R << ore::NV("FromBlock", I->getParent()->getName())
1146        << ore::NV("ToBlock", Preheader->getName());
1147    }
1148    return R;
1149  });
1150
1151  // Okay, the memset has been formed.  Zap the original store and anything that
1152  // feeds into it.
1153  for (auto *I : Stores) {
1154    if (MSSAU)
1155      MSSAU->removeMemoryAccess(I, true);
1156    deleteDeadInstruction(I);
1157  }
1158  if (MSSAU && VerifyMemorySSA)
1159    MSSAU->getMemorySSA()->verifyMemorySSA();
1160  ++NumMemSet;
1161  ExpCleaner.markResultUsed();
1162  return true;
1163}
1164
1165/// If the stored value is a strided load in the same loop with the same stride
1166/// this may be transformable into a memcpy.  This kicks in for stuff like
1167/// for (i) A[i] = B[i];
1168bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
1169                                                    const SCEV *BECount) {
1170  assert(SI->isUnordered() && "Expected only non-volatile non-ordered stores.");
1171
1172  Value *StorePtr = SI->getPointerOperand();
1173  const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
1174  unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType());
1175
1176  // The store must be feeding a non-volatile load.
1177  LoadInst *LI = cast<LoadInst>(SI->getValueOperand());
1178  assert(LI->isUnordered() && "Expected only non-volatile non-ordered loads.");
1179
1180  // See if the pointer expression is an AddRec like {base,+,1} on the current
1181  // loop, which indicates a strided load.  If we have something else, it's a
1182  // random load we can't handle.
1183  Value *LoadPtr = LI->getPointerOperand();
1184  const SCEVAddRecExpr *LoadEv = cast<SCEVAddRecExpr>(SE->getSCEV(LoadPtr));
1185
1186  const SCEV *StoreSizeSCEV = SE->getConstant(StorePtr->getType(), StoreSize);
1187  return processLoopStoreOfLoopLoad(StorePtr, LoadPtr, StoreSizeSCEV,
1188                                    SI->getAlign(), LI->getAlign(), SI, LI,
1189                                    StoreEv, LoadEv, BECount);
1190}
1191
1192namespace {
1193class MemmoveVerifier {
1194public:
1195  explicit MemmoveVerifier(const Value &LoadBasePtr, const Value &StoreBasePtr,
1196                           const DataLayout &DL)
1197      : DL(DL), BP1(llvm::GetPointerBaseWithConstantOffset(
1198                    LoadBasePtr.stripPointerCasts(), LoadOff, DL)),
1199        BP2(llvm::GetPointerBaseWithConstantOffset(
1200            StoreBasePtr.stripPointerCasts(), StoreOff, DL)),
1201        IsSameObject(BP1 == BP2) {}
1202
1203  bool loadAndStoreMayFormMemmove(unsigned StoreSize, bool IsNegStride,
1204                                  const Instruction &TheLoad,
1205                                  bool IsMemCpy) const {
1206    if (IsMemCpy) {
1207      // Ensure that LoadBasePtr is after StoreBasePtr or before StoreBasePtr
1208      // for negative stride.
1209      if ((!IsNegStride && LoadOff <= StoreOff) ||
1210          (IsNegStride && LoadOff >= StoreOff))
1211        return false;
1212    } else {
1213      // Ensure that LoadBasePtr is after StoreBasePtr or before StoreBasePtr
1214      // for negative stride. LoadBasePtr shouldn't overlap with StoreBasePtr.
1215      int64_t LoadSize =
1216          DL.getTypeSizeInBits(TheLoad.getType()).getFixedValue() / 8;
1217      if (BP1 != BP2 || LoadSize != int64_t(StoreSize))
1218        return false;
1219      if ((!IsNegStride && LoadOff < StoreOff + int64_t(StoreSize)) ||
1220          (IsNegStride && LoadOff + LoadSize > StoreOff))
1221        return false;
1222    }
1223    return true;
1224  }
1225
1226private:
1227  const DataLayout &DL;
1228  int64_t LoadOff = 0;
1229  int64_t StoreOff = 0;
1230  const Value *BP1;
1231  const Value *BP2;
1232
1233public:
1234  const bool IsSameObject;
1235};
1236} // namespace
1237
1238bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
1239    Value *DestPtr, Value *SourcePtr, const SCEV *StoreSizeSCEV,
1240    MaybeAlign StoreAlign, MaybeAlign LoadAlign, Instruction *TheStore,
1241    Instruction *TheLoad, const SCEVAddRecExpr *StoreEv,
1242    const SCEVAddRecExpr *LoadEv, const SCEV *BECount) {
1243
1244  // FIXME: until llvm.memcpy.inline supports dynamic sizes, we need to
1245  // conservatively bail here, since otherwise we may have to transform
1246  // llvm.memcpy.inline into llvm.memcpy which is illegal.
1247  if (isa<MemCpyInlineInst>(TheStore))
1248    return false;
1249
1250  // The trip count of the loop and the base pointer of the addrec SCEV is
1251  // guaranteed to be loop invariant, which means that it should dominate the
1252  // header.  This allows us to insert code for it in the preheader.
1253  BasicBlock *Preheader = CurLoop->getLoopPreheader();
1254  IRBuilder<> Builder(Preheader->getTerminator());
1255  SCEVExpander Expander(*SE, *DL, "loop-idiom");
1256
1257  SCEVExpanderCleaner ExpCleaner(Expander);
1258
1259  bool Changed = false;
1260  const SCEV *StrStart = StoreEv->getStart();
1261  unsigned StrAS = DestPtr->getType()->getPointerAddressSpace();
1262  Type *IntIdxTy = Builder.getIntNTy(DL->getIndexSizeInBits(StrAS));
1263
1264  APInt Stride = getStoreStride(StoreEv);
1265  const SCEVConstant *ConstStoreSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
1266
1267  // TODO: Deal with non-constant size; Currently expect constant store size
1268  assert(ConstStoreSize && "store size is expected to be a constant");
1269
1270  int64_t StoreSize = ConstStoreSize->getValue()->getZExtValue();
1271  bool IsNegStride = StoreSize == -Stride;
1272
1273  // Handle negative strided loops.
1274  if (IsNegStride)
1275    StrStart =
1276        getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSizeSCEV, SE);
1277
1278  // Okay, we have a strided store "p[i]" of a loaded value.  We can turn
1279  // this into a memcpy in the loop preheader now if we want.  However, this
1280  // would be unsafe to do if there is anything else in the loop that may read
1281  // or write the memory region we're storing to.  This includes the load that
1282  // feeds the stores.  Check for an alias by generating the base address and
1283  // checking everything.
1284  Value *StoreBasePtr = Expander.expandCodeFor(
1285      StrStart, Builder.getPtrTy(StrAS), Preheader->getTerminator());
1286
1287  // From here on out, conservatively report to the pass manager that we've
1288  // changed the IR, even if we later clean up these added instructions. There
1289  // may be structural differences e.g. in the order of use lists not accounted
1290  // for in just a textual dump of the IR. This is written as a variable, even
1291  // though statically all the places this dominates could be replaced with
1292  // 'true', with the hope that anyone trying to be clever / "more precise" with
1293  // the return value will read this comment, and leave them alone.
1294  Changed = true;
1295
1296  SmallPtrSet<Instruction *, 2> IgnoredInsts;
1297  IgnoredInsts.insert(TheStore);
1298
1299  bool IsMemCpy = isa<MemCpyInst>(TheStore);
1300  const StringRef InstRemark = IsMemCpy ? "memcpy" : "load and store";
1301
1302  bool LoopAccessStore =
1303      mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
1304                            StoreSizeSCEV, *AA, IgnoredInsts);
1305  if (LoopAccessStore) {
1306    // For memmove case it's not enough to guarantee that loop doesn't access
1307    // TheStore and TheLoad. Additionally we need to make sure that TheStore is
1308    // the only user of TheLoad.
1309    if (!TheLoad->hasOneUse())
1310      return Changed;
1311    IgnoredInsts.insert(TheLoad);
1312    if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
1313                              BECount, StoreSizeSCEV, *AA, IgnoredInsts)) {
1314      ORE.emit([&]() {
1315        return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore",
1316                                        TheStore)
1317               << ore::NV("Inst", InstRemark) << " in "
1318               << ore::NV("Function", TheStore->getFunction())
1319               << " function will not be hoisted: "
1320               << ore::NV("Reason", "The loop may access store location");
1321      });
1322      return Changed;
1323    }
1324    IgnoredInsts.erase(TheLoad);
1325  }
1326
1327  const SCEV *LdStart = LoadEv->getStart();
1328  unsigned LdAS = SourcePtr->getType()->getPointerAddressSpace();
1329
1330  // Handle negative strided loops.
1331  if (IsNegStride)
1332    LdStart =
1333        getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSizeSCEV, SE);
1334
1335  // For a memcpy, we have to make sure that the input array is not being
1336  // mutated by the loop.
1337  Value *LoadBasePtr = Expander.expandCodeFor(LdStart, Builder.getPtrTy(LdAS),
1338                                              Preheader->getTerminator());
1339
1340  // If the store is a memcpy instruction, we must check if it will write to
1341  // the load memory locations. So remove it from the ignored stores.
1342  MemmoveVerifier Verifier(*LoadBasePtr, *StoreBasePtr, *DL);
1343  if (IsMemCpy && !Verifier.IsSameObject)
1344    IgnoredInsts.erase(TheStore);
1345  if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
1346                            StoreSizeSCEV, *AA, IgnoredInsts)) {
1347    ORE.emit([&]() {
1348      return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad)
1349             << ore::NV("Inst", InstRemark) << " in "
1350             << ore::NV("Function", TheStore->getFunction())
1351             << " function will not be hoisted: "
1352             << ore::NV("Reason", "The loop may access load location");
1353    });
1354    return Changed;
1355  }
1356
1357  bool UseMemMove = IsMemCpy ? Verifier.IsSameObject : LoopAccessStore;
1358  if (UseMemMove)
1359    if (!Verifier.loadAndStoreMayFormMemmove(StoreSize, IsNegStride, *TheLoad,
1360                                             IsMemCpy))
1361      return Changed;
1362
1363  if (avoidLIRForMultiBlockLoop())
1364    return Changed;
1365
1366  // Okay, everything is safe, we can transform this!
1367
1368  const SCEV *NumBytesS =
1369      getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
1370
1371  Value *NumBytes =
1372      Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator());
1373
1374  AAMDNodes AATags = TheLoad->getAAMetadata();
1375  AAMDNodes StoreAATags = TheStore->getAAMetadata();
1376  AATags = AATags.merge(StoreAATags);
1377  if (auto CI = dyn_cast<ConstantInt>(NumBytes))
1378    AATags = AATags.extendTo(CI->getZExtValue());
1379  else
1380    AATags = AATags.extendTo(-1);
1381
1382  CallInst *NewCall = nullptr;
1383  // Check whether to generate an unordered atomic memcpy:
1384  //  If the load or store are atomic, then they must necessarily be unordered
1385  //  by previous checks.
1386  if (!TheStore->isAtomic() && !TheLoad->isAtomic()) {
1387    if (UseMemMove)
1388      NewCall = Builder.CreateMemMove(
1389          StoreBasePtr, StoreAlign, LoadBasePtr, LoadAlign, NumBytes,
1390          /*isVolatile=*/false, AATags.TBAA, AATags.Scope, AATags.NoAlias);
1391    else
1392      NewCall =
1393          Builder.CreateMemCpy(StoreBasePtr, StoreAlign, LoadBasePtr, LoadAlign,
1394                               NumBytes, /*isVolatile=*/false, AATags.TBAA,
1395                               AATags.TBAAStruct, AATags.Scope, AATags.NoAlias);
1396  } else {
1397    // For now don't support unordered atomic memmove.
1398    if (UseMemMove)
1399      return Changed;
1400    // We cannot allow unaligned ops for unordered load/store, so reject
1401    // anything where the alignment isn't at least the element size.
1402    assert((StoreAlign && LoadAlign) &&
1403           "Expect unordered load/store to have align.");
1404    if (*StoreAlign < StoreSize || *LoadAlign < StoreSize)
1405      return Changed;
1406
1407    // If the element.atomic memcpy is not lowered into explicit
1408    // loads/stores later, then it will be lowered into an element-size
1409    // specific lib call. If the lib call doesn't exist for our store size, then
1410    // we shouldn't generate the memcpy.
1411    if (StoreSize > TTI->getAtomicMemIntrinsicMaxElementSize())
1412      return Changed;
1413
1414    // Create the call.
1415    // Note that unordered atomic loads/stores are *required* by the spec to
1416    // have an alignment but non-atomic loads/stores may not.
1417    NewCall = Builder.CreateElementUnorderedAtomicMemCpy(
1418        StoreBasePtr, *StoreAlign, LoadBasePtr, *LoadAlign, NumBytes, StoreSize,
1419        AATags.TBAA, AATags.TBAAStruct, AATags.Scope, AATags.NoAlias);
1420  }
1421  NewCall->setDebugLoc(TheStore->getDebugLoc());
1422
1423  if (MSSAU) {
1424    MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB(
1425        NewCall, nullptr, NewCall->getParent(), MemorySSA::BeforeTerminator);
1426    MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true);
1427  }
1428
1429  LLVM_DEBUG(dbgs() << "  Formed new call: " << *NewCall << "\n"
1430                    << "    from load ptr=" << *LoadEv << " at: " << *TheLoad
1431                    << "\n"
1432                    << "    from store ptr=" << *StoreEv << " at: " << *TheStore
1433                    << "\n");
1434
1435  ORE.emit([&]() {
1436    return OptimizationRemark(DEBUG_TYPE, "ProcessLoopStoreOfLoopLoad",
1437                              NewCall->getDebugLoc(), Preheader)
1438           << "Formed a call to "
1439           << ore::NV("NewFunction", NewCall->getCalledFunction())
1440           << "() intrinsic from " << ore::NV("Inst", InstRemark)
1441           << " instruction in " << ore::NV("Function", TheStore->getFunction())
1442           << " function"
1443           << ore::setExtraArgs()
1444           << ore::NV("FromBlock", TheStore->getParent()->getName())
1445           << ore::NV("ToBlock", Preheader->getName());
1446  });
1447
1448  // Okay, a new call to memcpy/memmove has been formed.  Zap the original store
1449  // and anything that feeds into it.
1450  if (MSSAU)
1451    MSSAU->removeMemoryAccess(TheStore, true);
1452  deleteDeadInstruction(TheStore);
1453  if (MSSAU && VerifyMemorySSA)
1454    MSSAU->getMemorySSA()->verifyMemorySSA();
1455  if (UseMemMove)
1456    ++NumMemMove;
1457  else
1458    ++NumMemCpy;
1459  ExpCleaner.markResultUsed();
1460  return true;
1461}
1462
1463// When compiling for codesize we avoid idiom recognition for a multi-block loop
1464// unless it is a loop_memset idiom or a memset/memcpy idiom in a nested loop.
1465//
1466bool LoopIdiomRecognize::avoidLIRForMultiBlockLoop(bool IsMemset,
1467                                                   bool IsLoopMemset) {
1468  if (ApplyCodeSizeHeuristics && CurLoop->getNumBlocks() > 1) {
1469    if (CurLoop->isOutermost() && (!IsMemset || !IsLoopMemset)) {
1470      LLVM_DEBUG(dbgs() << "  " << CurLoop->getHeader()->getParent()->getName()
1471                        << " : LIR " << (IsMemset ? "Memset" : "Memcpy")
1472                        << " avoided: multi-block top-level loop\n");
1473      return true;
1474    }
1475  }
1476
1477  return false;
1478}
1479
1480bool LoopIdiomRecognize::runOnNoncountableLoop() {
1481  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
1482                    << CurLoop->getHeader()->getParent()->getName()
1483                    << "] Noncountable Loop %"
1484                    << CurLoop->getHeader()->getName() << "\n");
1485
1486  return recognizePopcount() || recognizeAndInsertFFS() ||
1487         recognizeShiftUntilBitTest() || recognizeShiftUntilZero();
1488}
1489
1490/// Check if the given conditional branch is based on the comparison between
1491/// a variable and zero, and if the variable is non-zero or zero (JmpOnZero is
1492/// true), the control yields to the loop entry. If the branch matches the
1493/// behavior, the variable involved in the comparison is returned. This function
1494/// will be called to see if the precondition and postcondition of the loop are
1495/// in desirable form.
1496static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry,
1497                             bool JmpOnZero = false) {
1498  if (!BI || !BI->isConditional())
1499    return nullptr;
1500
1501  ICmpInst *Cond = dyn_cast<ICmpInst>(BI->getCondition());
1502  if (!Cond)
1503    return nullptr;
1504
1505  ConstantInt *CmpZero = dyn_cast<ConstantInt>(Cond->getOperand(1));
1506  if (!CmpZero || !CmpZero->isZero())
1507    return nullptr;
1508
1509  BasicBlock *TrueSucc = BI->getSuccessor(0);
1510  BasicBlock *FalseSucc = BI->getSuccessor(1);
1511  if (JmpOnZero)
1512    std::swap(TrueSucc, FalseSucc);
1513
1514  ICmpInst::Predicate Pred = Cond->getPredicate();
1515  if ((Pred == ICmpInst::ICMP_NE && TrueSucc == LoopEntry) ||
1516      (Pred == ICmpInst::ICMP_EQ && FalseSucc == LoopEntry))
1517    return Cond->getOperand(0);
1518
1519  return nullptr;
1520}
1521
1522// Check if the recurrence variable `VarX` is in the right form to create
1523// the idiom. Returns the value coerced to a PHINode if so.
1524static PHINode *getRecurrenceVar(Value *VarX, Instruction *DefX,
1525                                 BasicBlock *LoopEntry) {
1526  auto *PhiX = dyn_cast<PHINode>(VarX);
1527  if (PhiX && PhiX->getParent() == LoopEntry &&
1528      (PhiX->getOperand(0) == DefX || PhiX->getOperand(1) == DefX))
1529    return PhiX;
1530  return nullptr;
1531}
1532
1533/// Return true iff the idiom is detected in the loop.
1534///
1535/// Additionally:
1536/// 1) \p CntInst is set to the instruction counting the population bit.
1537/// 2) \p CntPhi is set to the corresponding phi node.
1538/// 3) \p Var is set to the value whose population bits are being counted.
1539///
1540/// The core idiom we are trying to detect is:
1541/// \code
1542///    if (x0 != 0)
1543///      goto loop-exit // the precondition of the loop
1544///    cnt0 = init-val;
1545///    do {
1546///       x1 = phi (x0, x2);
1547///       cnt1 = phi(cnt0, cnt2);
1548///
1549///       cnt2 = cnt1 + 1;
1550///        ...
1551///       x2 = x1 & (x1 - 1);
1552///        ...
1553///    } while(x != 0);
1554///
1555/// loop-exit:
1556/// \endcode
1557static bool detectPopcountIdiom(Loop *CurLoop, BasicBlock *PreCondBB,
1558                                Instruction *&CntInst, PHINode *&CntPhi,
1559                                Value *&Var) {
1560  // step 1: Check to see if the look-back branch match this pattern:
1561  //    "if (a!=0) goto loop-entry".
1562  BasicBlock *LoopEntry;
1563  Instruction *DefX2, *CountInst;
1564  Value *VarX1, *VarX0;
1565  PHINode *PhiX, *CountPhi;
1566
1567  DefX2 = CountInst = nullptr;
1568  VarX1 = VarX0 = nullptr;
1569  PhiX = CountPhi = nullptr;
1570  LoopEntry = *(CurLoop->block_begin());
1571
1572  // step 1: Check if the loop-back branch is in desirable form.
1573  {
1574    if (Value *T = matchCondition(
1575            dyn_cast<BranchInst>(LoopEntry->getTerminator()), LoopEntry))
1576      DefX2 = dyn_cast<Instruction>(T);
1577    else
1578      return false;
1579  }
1580
1581  // step 2: detect instructions corresponding to "x2 = x1 & (x1 - 1)"
1582  {
1583    if (!DefX2 || DefX2->getOpcode() != Instruction::And)
1584      return false;
1585
1586    BinaryOperator *SubOneOp;
1587
1588    if ((SubOneOp = dyn_cast<BinaryOperator>(DefX2->getOperand(0))))
1589      VarX1 = DefX2->getOperand(1);
1590    else {
1591      VarX1 = DefX2->getOperand(0);
1592      SubOneOp = dyn_cast<BinaryOperator>(DefX2->getOperand(1));
1593    }
1594    if (!SubOneOp || SubOneOp->getOperand(0) != VarX1)
1595      return false;
1596
1597    ConstantInt *Dec = dyn_cast<ConstantInt>(SubOneOp->getOperand(1));
1598    if (!Dec ||
1599        !((SubOneOp->getOpcode() == Instruction::Sub && Dec->isOne()) ||
1600          (SubOneOp->getOpcode() == Instruction::Add &&
1601           Dec->isMinusOne()))) {
1602      return false;
1603    }
1604  }
1605
1606  // step 3: Check the recurrence of variable X
1607  PhiX = getRecurrenceVar(VarX1, DefX2, LoopEntry);
1608  if (!PhiX)
1609    return false;
1610
1611  // step 4: Find the instruction which count the population: cnt2 = cnt1 + 1
1612  {
1613    CountInst = nullptr;
1614    for (Instruction &Inst : llvm::make_range(
1615             LoopEntry->getFirstNonPHI()->getIterator(), LoopEntry->end())) {
1616      if (Inst.getOpcode() != Instruction::Add)
1617        continue;
1618
1619      ConstantInt *Inc = dyn_cast<ConstantInt>(Inst.getOperand(1));
1620      if (!Inc || !Inc->isOne())
1621        continue;
1622
1623      PHINode *Phi = getRecurrenceVar(Inst.getOperand(0), &Inst, LoopEntry);
1624      if (!Phi)
1625        continue;
1626
1627      // Check if the result of the instruction is live of the loop.
1628      bool LiveOutLoop = false;
1629      for (User *U : Inst.users()) {
1630        if ((cast<Instruction>(U))->getParent() != LoopEntry) {
1631          LiveOutLoop = true;
1632          break;
1633        }
1634      }
1635
1636      if (LiveOutLoop) {
1637        CountInst = &Inst;
1638        CountPhi = Phi;
1639        break;
1640      }
1641    }
1642
1643    if (!CountInst)
1644      return false;
1645  }
1646
1647  // step 5: check if the precondition is in this form:
1648  //   "if (x != 0) goto loop-head ; else goto somewhere-we-don't-care;"
1649  {
1650    auto *PreCondBr = dyn_cast<BranchInst>(PreCondBB->getTerminator());
1651    Value *T = matchCondition(PreCondBr, CurLoop->getLoopPreheader());
1652    if (T != PhiX->getOperand(0) && T != PhiX->getOperand(1))
1653      return false;
1654
1655    CntInst = CountInst;
1656    CntPhi = CountPhi;
1657    Var = T;
1658  }
1659
1660  return true;
1661}
1662
1663/// Return true if the idiom is detected in the loop.
1664///
1665/// Additionally:
1666/// 1) \p CntInst is set to the instruction Counting Leading Zeros (CTLZ)
1667///       or nullptr if there is no such.
1668/// 2) \p CntPhi is set to the corresponding phi node
1669///       or nullptr if there is no such.
1670/// 3) \p Var is set to the value whose CTLZ could be used.
1671/// 4) \p DefX is set to the instruction calculating Loop exit condition.
1672///
1673/// The core idiom we are trying to detect is:
1674/// \code
1675///    if (x0 == 0)
1676///      goto loop-exit // the precondition of the loop
1677///    cnt0 = init-val;
1678///    do {
1679///       x = phi (x0, x.next);   //PhiX
1680///       cnt = phi(cnt0, cnt.next);
1681///
1682///       cnt.next = cnt + 1;
1683///        ...
1684///       x.next = x >> 1;   // DefX
1685///        ...
1686///    } while(x.next != 0);
1687///
1688/// loop-exit:
1689/// \endcode
1690static bool detectShiftUntilZeroIdiom(Loop *CurLoop, const DataLayout &DL,
1691                                      Intrinsic::ID &IntrinID, Value *&InitX,
1692                                      Instruction *&CntInst, PHINode *&CntPhi,
1693                                      Instruction *&DefX) {
1694  BasicBlock *LoopEntry;
1695  Value *VarX = nullptr;
1696
1697  DefX = nullptr;
1698  CntInst = nullptr;
1699  CntPhi = nullptr;
1700  LoopEntry = *(CurLoop->block_begin());
1701
1702  // step 1: Check if the loop-back branch is in desirable form.
1703  if (Value *T = matchCondition(
1704          dyn_cast<BranchInst>(LoopEntry->getTerminator()), LoopEntry))
1705    DefX = dyn_cast<Instruction>(T);
1706  else
1707    return false;
1708
1709  // step 2: detect instructions corresponding to "x.next = x >> 1 or x << 1"
1710  if (!DefX || !DefX->isShift())
1711    return false;
1712  IntrinID = DefX->getOpcode() == Instruction::Shl ? Intrinsic::cttz :
1713                                                     Intrinsic::ctlz;
1714  ConstantInt *Shft = dyn_cast<ConstantInt>(DefX->getOperand(1));
1715  if (!Shft || !Shft->isOne())
1716    return false;
1717  VarX = DefX->getOperand(0);
1718
1719  // step 3: Check the recurrence of variable X
1720  PHINode *PhiX = getRecurrenceVar(VarX, DefX, LoopEntry);
1721  if (!PhiX)
1722    return false;
1723
1724  InitX = PhiX->getIncomingValueForBlock(CurLoop->getLoopPreheader());
1725
1726  // Make sure the initial value can't be negative otherwise the ashr in the
1727  // loop might never reach zero which would make the loop infinite.
1728  if (DefX->getOpcode() == Instruction::AShr && !isKnownNonNegative(InitX, DL))
1729    return false;
1730
1731  // step 4: Find the instruction which count the CTLZ: cnt.next = cnt + 1
1732  //         or cnt.next = cnt + -1.
1733  // TODO: We can skip the step. If loop trip count is known (CTLZ),
1734  //       then all uses of "cnt.next" could be optimized to the trip count
1735  //       plus "cnt0". Currently it is not optimized.
1736  //       This step could be used to detect POPCNT instruction:
1737  //       cnt.next = cnt + (x.next & 1)
1738  for (Instruction &Inst : llvm::make_range(
1739           LoopEntry->getFirstNonPHI()->getIterator(), LoopEntry->end())) {
1740    if (Inst.getOpcode() != Instruction::Add)
1741      continue;
1742
1743    ConstantInt *Inc = dyn_cast<ConstantInt>(Inst.getOperand(1));
1744    if (!Inc || (!Inc->isOne() && !Inc->isMinusOne()))
1745      continue;
1746
1747    PHINode *Phi = getRecurrenceVar(Inst.getOperand(0), &Inst, LoopEntry);
1748    if (!Phi)
1749      continue;
1750
1751    CntInst = &Inst;
1752    CntPhi = Phi;
1753    break;
1754  }
1755  if (!CntInst)
1756    return false;
1757
1758  return true;
1759}
1760
1761/// Recognize CTLZ or CTTZ idiom in a non-countable loop and convert the loop
1762/// to countable (with CTLZ / CTTZ trip count). If CTLZ / CTTZ inserted as a new
1763/// trip count returns true; otherwise, returns false.
1764bool LoopIdiomRecognize::recognizeAndInsertFFS() {
1765  // Give up if the loop has multiple blocks or multiple backedges.
1766  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
1767    return false;
1768
1769  Intrinsic::ID IntrinID;
1770  Value *InitX;
1771  Instruction *DefX = nullptr;
1772  PHINode *CntPhi = nullptr;
1773  Instruction *CntInst = nullptr;
1774  // Help decide if transformation is profitable. For ShiftUntilZero idiom,
1775  // this is always 6.
1776  size_t IdiomCanonicalSize = 6;
1777
1778  if (!detectShiftUntilZeroIdiom(CurLoop, *DL, IntrinID, InitX,
1779                                 CntInst, CntPhi, DefX))
1780    return false;
1781
1782  bool IsCntPhiUsedOutsideLoop = false;
1783  for (User *U : CntPhi->users())
1784    if (!CurLoop->contains(cast<Instruction>(U))) {
1785      IsCntPhiUsedOutsideLoop = true;
1786      break;
1787    }
1788  bool IsCntInstUsedOutsideLoop = false;
1789  for (User *U : CntInst->users())
1790    if (!CurLoop->contains(cast<Instruction>(U))) {
1791      IsCntInstUsedOutsideLoop = true;
1792      break;
1793    }
1794  // If both CntInst and CntPhi are used outside the loop the profitability
1795  // is questionable.
1796  if (IsCntInstUsedOutsideLoop && IsCntPhiUsedOutsideLoop)
1797    return false;
1798
1799  // For some CPUs result of CTLZ(X) intrinsic is undefined
1800  // when X is 0. If we can not guarantee X != 0, we need to check this
1801  // when expand.
1802  bool ZeroCheck = false;
1803  // It is safe to assume Preheader exist as it was checked in
1804  // parent function RunOnLoop.
1805  BasicBlock *PH = CurLoop->getLoopPreheader();
1806
1807  // If we are using the count instruction outside the loop, make sure we
1808  // have a zero check as a precondition. Without the check the loop would run
1809  // one iteration for before any check of the input value. This means 0 and 1
1810  // would have identical behavior in the original loop and thus
1811  if (!IsCntPhiUsedOutsideLoop) {
1812    auto *PreCondBB = PH->getSinglePredecessor();
1813    if (!PreCondBB)
1814      return false;
1815    auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator());
1816    if (!PreCondBI)
1817      return false;
1818    if (matchCondition(PreCondBI, PH) != InitX)
1819      return false;
1820    ZeroCheck = true;
1821  }
1822
1823  // Check if CTLZ / CTTZ intrinsic is profitable. Assume it is always
1824  // profitable if we delete the loop.
1825
1826  // the loop has only 6 instructions:
1827  //  %n.addr.0 = phi [ %n, %entry ], [ %shr, %while.cond ]
1828  //  %i.0 = phi [ %i0, %entry ], [ %inc, %while.cond ]
1829  //  %shr = ashr %n.addr.0, 1
1830  //  %tobool = icmp eq %shr, 0
1831  //  %inc = add nsw %i.0, 1
1832  //  br i1 %tobool
1833
1834  const Value *Args[] = {InitX,
1835                         ConstantInt::getBool(InitX->getContext(), ZeroCheck)};
1836
1837  // @llvm.dbg doesn't count as they have no semantic effect.
1838  auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();
1839  uint32_t HeaderSize =
1840      std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end());
1841
1842  IntrinsicCostAttributes Attrs(IntrinID, InitX->getType(), Args);
1843  InstructionCost Cost =
1844    TTI->getIntrinsicInstrCost(Attrs, TargetTransformInfo::TCK_SizeAndLatency);
1845  if (HeaderSize != IdiomCanonicalSize &&
1846      Cost > TargetTransformInfo::TCC_Basic)
1847    return false;
1848
1849  transformLoopToCountable(IntrinID, PH, CntInst, CntPhi, InitX, DefX,
1850                           DefX->getDebugLoc(), ZeroCheck,
1851                           IsCntPhiUsedOutsideLoop);
1852  return true;
1853}
1854
1855/// Recognizes a population count idiom in a non-countable loop.
1856///
1857/// If detected, transforms the relevant code to issue the popcount intrinsic
1858/// function call, and returns true; otherwise, returns false.
1859bool LoopIdiomRecognize::recognizePopcount() {
1860  if (TTI->getPopcntSupport(32) != TargetTransformInfo::PSK_FastHardware)
1861    return false;
1862
1863  // Counting population are usually conducted by few arithmetic instructions.
1864  // Such instructions can be easily "absorbed" by vacant slots in a
1865  // non-compact loop. Therefore, recognizing popcount idiom only makes sense
1866  // in a compact loop.
1867
1868  // Give up if the loop has multiple blocks or multiple backedges.
1869  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
1870    return false;
1871
1872  BasicBlock *LoopBody = *(CurLoop->block_begin());
1873  if (LoopBody->size() >= 20) {
1874    // The loop is too big, bail out.
1875    return false;
1876  }
1877
1878  // It should have a preheader containing nothing but an unconditional branch.
1879  BasicBlock *PH = CurLoop->getLoopPreheader();
1880  if (!PH || &PH->front() != PH->getTerminator())
1881    return false;
1882  auto *EntryBI = dyn_cast<BranchInst>(PH->getTerminator());
1883  if (!EntryBI || EntryBI->isConditional())
1884    return false;
1885
1886  // It should have a precondition block where the generated popcount intrinsic
1887  // function can be inserted.
1888  auto *PreCondBB = PH->getSinglePredecessor();
1889  if (!PreCondBB)
1890    return false;
1891  auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator());
1892  if (!PreCondBI || PreCondBI->isUnconditional())
1893    return false;
1894
1895  Instruction *CntInst;
1896  PHINode *CntPhi;
1897  Value *Val;
1898  if (!detectPopcountIdiom(CurLoop, PreCondBB, CntInst, CntPhi, Val))
1899    return false;
1900
1901  transformLoopToPopcount(PreCondBB, CntInst, CntPhi, Val);
1902  return true;
1903}
1904
1905static CallInst *createPopcntIntrinsic(IRBuilder<> &IRBuilder, Value *Val,
1906                                       const DebugLoc &DL) {
1907  Value *Ops[] = {Val};
1908  Type *Tys[] = {Val->getType()};
1909
1910  Module *M = IRBuilder.GetInsertBlock()->getParent()->getParent();
1911  Function *Func = Intrinsic::getDeclaration(M, Intrinsic::ctpop, Tys);
1912  CallInst *CI = IRBuilder.CreateCall(Func, Ops);
1913  CI->setDebugLoc(DL);
1914
1915  return CI;
1916}
1917
1918static CallInst *createFFSIntrinsic(IRBuilder<> &IRBuilder, Value *Val,
1919                                    const DebugLoc &DL, bool ZeroCheck,
1920                                    Intrinsic::ID IID) {
1921  Value *Ops[] = {Val, IRBuilder.getInt1(ZeroCheck)};
1922  Type *Tys[] = {Val->getType()};
1923
1924  Module *M = IRBuilder.GetInsertBlock()->getParent()->getParent();
1925  Function *Func = Intrinsic::getDeclaration(M, IID, Tys);
1926  CallInst *CI = IRBuilder.CreateCall(Func, Ops);
1927  CI->setDebugLoc(DL);
1928
1929  return CI;
1930}
1931
1932/// Transform the following loop (Using CTLZ, CTTZ is similar):
1933/// loop:
1934///   CntPhi = PHI [Cnt0, CntInst]
1935///   PhiX = PHI [InitX, DefX]
1936///   CntInst = CntPhi + 1
1937///   DefX = PhiX >> 1
1938///   LOOP_BODY
1939///   Br: loop if (DefX != 0)
1940/// Use(CntPhi) or Use(CntInst)
1941///
1942/// Into:
1943/// If CntPhi used outside the loop:
1944///   CountPrev = BitWidth(InitX) - CTLZ(InitX >> 1)
1945///   Count = CountPrev + 1
1946/// else
1947///   Count = BitWidth(InitX) - CTLZ(InitX)
1948/// loop:
1949///   CntPhi = PHI [Cnt0, CntInst]
1950///   PhiX = PHI [InitX, DefX]
1951///   PhiCount = PHI [Count, Dec]
1952///   CntInst = CntPhi + 1
1953///   DefX = PhiX >> 1
1954///   Dec = PhiCount - 1
1955///   LOOP_BODY
1956///   Br: loop if (Dec != 0)
1957/// Use(CountPrev + Cnt0) // Use(CntPhi)
1958/// or
1959/// Use(Count + Cnt0) // Use(CntInst)
1960///
1961/// If LOOP_BODY is empty the loop will be deleted.
1962/// If CntInst and DefX are not used in LOOP_BODY they will be removed.
1963void LoopIdiomRecognize::transformLoopToCountable(
1964    Intrinsic::ID IntrinID, BasicBlock *Preheader, Instruction *CntInst,
1965    PHINode *CntPhi, Value *InitX, Instruction *DefX, const DebugLoc &DL,
1966    bool ZeroCheck, bool IsCntPhiUsedOutsideLoop) {
1967  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
1968
1969  // Step 1: Insert the CTLZ/CTTZ instruction at the end of the preheader block
1970  IRBuilder<> Builder(PreheaderBr);
1971  Builder.SetCurrentDebugLocation(DL);
1972
1973  // If there are no uses of CntPhi crate:
1974  //   Count = BitWidth - CTLZ(InitX);
1975  //   NewCount = Count;
1976  // If there are uses of CntPhi create:
1977  //   NewCount = BitWidth - CTLZ(InitX >> 1);
1978  //   Count = NewCount + 1;
1979  Value *InitXNext;
1980  if (IsCntPhiUsedOutsideLoop) {
1981    if (DefX->getOpcode() == Instruction::AShr)
1982      InitXNext = Builder.CreateAShr(InitX, 1);
1983    else if (DefX->getOpcode() == Instruction::LShr)
1984      InitXNext = Builder.CreateLShr(InitX, 1);
1985    else if (DefX->getOpcode() == Instruction::Shl) // cttz
1986      InitXNext = Builder.CreateShl(InitX, 1);
1987    else
1988      llvm_unreachable("Unexpected opcode!");
1989  } else
1990    InitXNext = InitX;
1991  Value *Count =
1992      createFFSIntrinsic(Builder, InitXNext, DL, ZeroCheck, IntrinID);
1993  Type *CountTy = Count->getType();
1994  Count = Builder.CreateSub(
1995      ConstantInt::get(CountTy, CountTy->getIntegerBitWidth()), Count);
1996  Value *NewCount = Count;
1997  if (IsCntPhiUsedOutsideLoop)
1998    Count = Builder.CreateAdd(Count, ConstantInt::get(CountTy, 1));
1999
2000  NewCount = Builder.CreateZExtOrTrunc(NewCount, CntInst->getType());
2001
2002  Value *CntInitVal = CntPhi->getIncomingValueForBlock(Preheader);
2003  if (cast<ConstantInt>(CntInst->getOperand(1))->isOne()) {
2004    // If the counter was being incremented in the loop, add NewCount to the
2005    // counter's initial value, but only if the initial value is not zero.
2006    ConstantInt *InitConst = dyn_cast<ConstantInt>(CntInitVal);
2007    if (!InitConst || !InitConst->isZero())
2008      NewCount = Builder.CreateAdd(NewCount, CntInitVal);
2009  } else {
2010    // If the count was being decremented in the loop, subtract NewCount from
2011    // the counter's initial value.
2012    NewCount = Builder.CreateSub(CntInitVal, NewCount);
2013  }
2014
2015  // Step 2: Insert new IV and loop condition:
2016  // loop:
2017  //   ...
2018  //   PhiCount = PHI [Count, Dec]
2019  //   ...
2020  //   Dec = PhiCount - 1
2021  //   ...
2022  //   Br: loop if (Dec != 0)
2023  BasicBlock *Body = *(CurLoop->block_begin());
2024  auto *LbBr = cast<BranchInst>(Body->getTerminator());
2025  ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition());
2026
2027  PHINode *TcPhi = PHINode::Create(CountTy, 2, "tcphi");
2028  TcPhi->insertBefore(Body->begin());
2029
2030  Builder.SetInsertPoint(LbCond);
2031  Instruction *TcDec = cast<Instruction>(Builder.CreateSub(
2032      TcPhi, ConstantInt::get(CountTy, 1), "tcdec", false, true));
2033
2034  TcPhi->addIncoming(Count, Preheader);
2035  TcPhi->addIncoming(TcDec, Body);
2036
2037  CmpInst::Predicate Pred =
2038      (LbBr->getSuccessor(0) == Body) ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ;
2039  LbCond->setPredicate(Pred);
2040  LbCond->setOperand(0, TcDec);
2041  LbCond->setOperand(1, ConstantInt::get(CountTy, 0));
2042
2043  // Step 3: All the references to the original counter outside
2044  //  the loop are replaced with the NewCount
2045  if (IsCntPhiUsedOutsideLoop)
2046    CntPhi->replaceUsesOutsideBlock(NewCount, Body);
2047  else
2048    CntInst->replaceUsesOutsideBlock(NewCount, Body);
2049
2050  // step 4: Forget the "non-computable" trip-count SCEV associated with the
2051  //   loop. The loop would otherwise not be deleted even if it becomes empty.
2052  SE->forgetLoop(CurLoop);
2053}
2054
2055void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB,
2056                                                 Instruction *CntInst,
2057                                                 PHINode *CntPhi, Value *Var) {
2058  BasicBlock *PreHead = CurLoop->getLoopPreheader();
2059  auto *PreCondBr = cast<BranchInst>(PreCondBB->getTerminator());
2060  const DebugLoc &DL = CntInst->getDebugLoc();
2061
2062  // Assuming before transformation, the loop is following:
2063  //  if (x) // the precondition
2064  //     do { cnt++; x &= x - 1; } while(x);
2065
2066  // Step 1: Insert the ctpop instruction at the end of the precondition block
2067  IRBuilder<> Builder(PreCondBr);
2068  Value *PopCnt, *PopCntZext, *NewCount, *TripCnt;
2069  {
2070    PopCnt = createPopcntIntrinsic(Builder, Var, DL);
2071    NewCount = PopCntZext =
2072        Builder.CreateZExtOrTrunc(PopCnt, cast<IntegerType>(CntPhi->getType()));
2073
2074    if (NewCount != PopCnt)
2075      (cast<Instruction>(NewCount))->setDebugLoc(DL);
2076
2077    // TripCnt is exactly the number of iterations the loop has
2078    TripCnt = NewCount;
2079
2080    // If the population counter's initial value is not zero, insert Add Inst.
2081    Value *CntInitVal = CntPhi->getIncomingValueForBlock(PreHead);
2082    ConstantInt *InitConst = dyn_cast<ConstantInt>(CntInitVal);
2083    if (!InitConst || !InitConst->isZero()) {
2084      NewCount = Builder.CreateAdd(NewCount, CntInitVal);
2085      (cast<Instruction>(NewCount))->setDebugLoc(DL);
2086    }
2087  }
2088
2089  // Step 2: Replace the precondition from "if (x == 0) goto loop-exit" to
2090  //   "if (NewCount == 0) loop-exit". Without this change, the intrinsic
2091  //   function would be partial dead code, and downstream passes will drag
2092  //   it back from the precondition block to the preheader.
2093  {
2094    ICmpInst *PreCond = cast<ICmpInst>(PreCondBr->getCondition());
2095
2096    Value *Opnd0 = PopCntZext;
2097    Value *Opnd1 = ConstantInt::get(PopCntZext->getType(), 0);
2098    if (PreCond->getOperand(0) != Var)
2099      std::swap(Opnd0, Opnd1);
2100
2101    ICmpInst *NewPreCond = cast<ICmpInst>(
2102        Builder.CreateICmp(PreCond->getPredicate(), Opnd0, Opnd1));
2103    PreCondBr->setCondition(NewPreCond);
2104
2105    RecursivelyDeleteTriviallyDeadInstructions(PreCond, TLI);
2106  }
2107
2108  // Step 3: Note that the population count is exactly the trip count of the
2109  // loop in question, which enable us to convert the loop from noncountable
2110  // loop into a countable one. The benefit is twofold:
2111  //
2112  //  - If the loop only counts population, the entire loop becomes dead after
2113  //    the transformation. It is a lot easier to prove a countable loop dead
2114  //    than to prove a noncountable one. (In some C dialects, an infinite loop
2115  //    isn't dead even if it computes nothing useful. In general, DCE needs
2116  //    to prove a noncountable loop finite before safely delete it.)
2117  //
2118  //  - If the loop also performs something else, it remains alive.
2119  //    Since it is transformed to countable form, it can be aggressively
2120  //    optimized by some optimizations which are in general not applicable
2121  //    to a noncountable loop.
2122  //
2123  // After this step, this loop (conceptually) would look like following:
2124  //   newcnt = __builtin_ctpop(x);
2125  //   t = newcnt;
2126  //   if (x)
2127  //     do { cnt++; x &= x-1; t--) } while (t > 0);
2128  BasicBlock *Body = *(CurLoop->block_begin());
2129  {
2130    auto *LbBr = cast<BranchInst>(Body->getTerminator());
2131    ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition());
2132    Type *Ty = TripCnt->getType();
2133
2134    PHINode *TcPhi = PHINode::Create(Ty, 2, "tcphi");
2135    TcPhi->insertBefore(Body->begin());
2136
2137    Builder.SetInsertPoint(LbCond);
2138    Instruction *TcDec = cast<Instruction>(
2139        Builder.CreateSub(TcPhi, ConstantInt::get(Ty, 1),
2140                          "tcdec", false, true));
2141
2142    TcPhi->addIncoming(TripCnt, PreHead);
2143    TcPhi->addIncoming(TcDec, Body);
2144
2145    CmpInst::Predicate Pred =
2146        (LbBr->getSuccessor(0) == Body) ? CmpInst::ICMP_UGT : CmpInst::ICMP_SLE;
2147    LbCond->setPredicate(Pred);
2148    LbCond->setOperand(0, TcDec);
2149    LbCond->setOperand(1, ConstantInt::get(Ty, 0));
2150  }
2151
2152  // Step 4: All the references to the original population counter outside
2153  //  the loop are replaced with the NewCount -- the value returned from
2154  //  __builtin_ctpop().
2155  CntInst->replaceUsesOutsideBlock(NewCount, Body);
2156
2157  // step 5: Forget the "non-computable" trip-count SCEV associated with the
2158  //   loop. The loop would otherwise not be deleted even if it becomes empty.
2159  SE->forgetLoop(CurLoop);
2160}
2161
2162/// Match loop-invariant value.
2163template <typename SubPattern_t> struct match_LoopInvariant {
2164  SubPattern_t SubPattern;
2165  const Loop *L;
2166
2167  match_LoopInvariant(const SubPattern_t &SP, const Loop *L)
2168      : SubPattern(SP), L(L) {}
2169
2170  template <typename ITy> bool match(ITy *V) {
2171    return L->isLoopInvariant(V) && SubPattern.match(V);
2172  }
2173};
2174
2175/// Matches if the value is loop-invariant.
2176template <typename Ty>
2177inline match_LoopInvariant<Ty> m_LoopInvariant(const Ty &M, const Loop *L) {
2178  return match_LoopInvariant<Ty>(M, L);
2179}
2180
2181/// Return true if the idiom is detected in the loop.
2182///
2183/// The core idiom we are trying to detect is:
2184/// \code
2185///   entry:
2186///     <...>
2187///     %bitmask = shl i32 1, %bitpos
2188///     br label %loop
2189///
2190///   loop:
2191///     %x.curr = phi i32 [ %x, %entry ], [ %x.next, %loop ]
2192///     %x.curr.bitmasked = and i32 %x.curr, %bitmask
2193///     %x.curr.isbitunset = icmp eq i32 %x.curr.bitmasked, 0
2194///     %x.next = shl i32 %x.curr, 1
2195///     <...>
2196///     br i1 %x.curr.isbitunset, label %loop, label %end
2197///
2198///   end:
2199///     %x.curr.res = phi i32 [ %x.curr, %loop ] <...>
2200///     %x.next.res = phi i32 [ %x.next, %loop ] <...>
2201///     <...>
2202/// \endcode
2203static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
2204                                         Value *&BitMask, Value *&BitPos,
2205                                         Value *&CurrX, Instruction *&NextX) {
2206  LLVM_DEBUG(dbgs() << DEBUG_TYPE
2207             " Performing shift-until-bittest idiom detection.\n");
2208
2209  // Give up if the loop has multiple blocks or multiple backedges.
2210  if (CurLoop->getNumBlocks() != 1 || CurLoop->getNumBackEdges() != 1) {
2211    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad block/backedge count.\n");
2212    return false;
2213  }
2214
2215  BasicBlock *LoopHeaderBB = CurLoop->getHeader();
2216  BasicBlock *LoopPreheaderBB = CurLoop->getLoopPreheader();
2217  assert(LoopPreheaderBB && "There is always a loop preheader.");
2218
2219  using namespace PatternMatch;
2220
2221  // Step 1: Check if the loop backedge is in desirable form.
2222
2223  ICmpInst::Predicate Pred;
2224  Value *CmpLHS, *CmpRHS;
2225  BasicBlock *TrueBB, *FalseBB;
2226  if (!match(LoopHeaderBB->getTerminator(),
2227             m_Br(m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)),
2228                  m_BasicBlock(TrueBB), m_BasicBlock(FalseBB)))) {
2229    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge structure.\n");
2230    return false;
2231  }
2232
2233  // Step 2: Check if the backedge's condition is in desirable form.
2234
2235  auto MatchVariableBitMask = [&]() {
2236    return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
2237           match(CmpLHS,
2238                 m_c_And(m_Value(CurrX),
2239                         m_CombineAnd(
2240                             m_Value(BitMask),
2241                             m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
2242                                             CurLoop))));
2243  };
2244  auto MatchConstantBitMask = [&]() {
2245    return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
2246           match(CmpLHS, m_And(m_Value(CurrX),
2247                               m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
2248           (BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
2249  };
2250  auto MatchDecomposableConstantBitMask = [&]() {
2251    APInt Mask;
2252    return llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CurrX, Mask) &&
2253           ICmpInst::isEquality(Pred) && Mask.isPowerOf2() &&
2254           (BitMask = ConstantInt::get(CurrX->getType(), Mask)) &&
2255           (BitPos = ConstantInt::get(CurrX->getType(), Mask.logBase2()));
2256  };
2257
2258  if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
2259      !MatchDecomposableConstantBitMask()) {
2260    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
2261    return false;
2262  }
2263
2264  // Step 3: Check if the recurrence is in desirable form.
2265  auto *CurrXPN = dyn_cast<PHINode>(CurrX);
2266  if (!CurrXPN || CurrXPN->getParent() != LoopHeaderBB) {
2267    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Not an expected PHI node.\n");
2268    return false;
2269  }
2270
2271  BaseX = CurrXPN->getIncomingValueForBlock(LoopPreheaderBB);
2272  NextX =
2273      dyn_cast<Instruction>(CurrXPN->getIncomingValueForBlock(LoopHeaderBB));
2274
2275  assert(CurLoop->isLoopInvariant(BaseX) &&
2276         "Expected BaseX to be avaliable in the preheader!");
2277
2278  if (!NextX || !match(NextX, m_Shl(m_Specific(CurrX), m_One()))) {
2279    // FIXME: support right-shift?
2280    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad recurrence.\n");
2281    return false;
2282  }
2283
2284  // Step 4: Check if the backedge's destinations are in desirable form.
2285
2286  assert(ICmpInst::isEquality(Pred) &&
2287         "Should only get equality predicates here.");
2288
2289  // cmp-br is commutative, so canonicalize to a single variant.
2290  if (Pred != ICmpInst::Predicate::ICMP_EQ) {
2291    Pred = ICmpInst::getInversePredicate(Pred);
2292    std::swap(TrueBB, FalseBB);
2293  }
2294
2295  // We expect to exit loop when comparison yields false,
2296  // so when it yields true we should branch back to loop header.
2297  if (TrueBB != LoopHeaderBB) {
2298    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge flow.\n");
2299    return false;
2300  }
2301
2302  // Okay, idiom checks out.
2303  return true;
2304}
2305
2306/// Look for the following loop:
2307/// \code
2308///   entry:
2309///     <...>
2310///     %bitmask = shl i32 1, %bitpos
2311///     br label %loop
2312///
2313///   loop:
2314///     %x.curr = phi i32 [ %x, %entry ], [ %x.next, %loop ]
2315///     %x.curr.bitmasked = and i32 %x.curr, %bitmask
2316///     %x.curr.isbitunset = icmp eq i32 %x.curr.bitmasked, 0
2317///     %x.next = shl i32 %x.curr, 1
2318///     <...>
2319///     br i1 %x.curr.isbitunset, label %loop, label %end
2320///
2321///   end:
2322///     %x.curr.res = phi i32 [ %x.curr, %loop ] <...>
2323///     %x.next.res = phi i32 [ %x.next, %loop ] <...>
2324///     <...>
2325/// \endcode
2326///
2327/// And transform it into:
2328/// \code
2329///   entry:
2330///     %bitmask = shl i32 1, %bitpos
2331///     %lowbitmask = add i32 %bitmask, -1
2332///     %mask = or i32 %lowbitmask, %bitmask
2333///     %x.masked = and i32 %x, %mask
2334///     %x.masked.numleadingzeros = call i32 @llvm.ctlz.i32(i32 %x.masked,
2335///                                                         i1 true)
2336///     %x.masked.numactivebits = sub i32 32, %x.masked.numleadingzeros
2337///     %x.masked.leadingonepos = add i32 %x.masked.numactivebits, -1
2338///     %backedgetakencount = sub i32 %bitpos, %x.masked.leadingonepos
2339///     %tripcount = add i32 %backedgetakencount, 1
2340///     %x.curr = shl i32 %x, %backedgetakencount
2341///     %x.next = shl i32 %x, %tripcount
2342///     br label %loop
2343///
2344///   loop:
2345///     %loop.iv = phi i32 [ 0, %entry ], [ %loop.iv.next, %loop ]
2346///     %loop.iv.next = add nuw i32 %loop.iv, 1
2347///     %loop.ivcheck = icmp eq i32 %loop.iv.next, %tripcount
2348///     <...>
2349///     br i1 %loop.ivcheck, label %end, label %loop
2350///
2351///   end:
2352///     %x.curr.res = phi i32 [ %x.curr, %loop ] <...>
2353///     %x.next.res = phi i32 [ %x.next, %loop ] <...>
2354///     <...>
2355/// \endcode
2356bool LoopIdiomRecognize::recognizeShiftUntilBitTest() {
2357  bool MadeChange = false;
2358
2359  Value *X, *BitMask, *BitPos, *XCurr;
2360  Instruction *XNext;
2361  if (!detectShiftUntilBitTestIdiom(CurLoop, X, BitMask, BitPos, XCurr,
2362                                    XNext)) {
2363    LLVM_DEBUG(dbgs() << DEBUG_TYPE
2364               " shift-until-bittest idiom detection failed.\n");
2365    return MadeChange;
2366  }
2367  LLVM_DEBUG(dbgs() << DEBUG_TYPE " shift-until-bittest idiom detected!\n");
2368
2369  // Ok, it is the idiom we were looking for, we *could* transform this loop,
2370  // but is it profitable to transform?
2371
2372  BasicBlock *LoopHeaderBB = CurLoop->getHeader();
2373  BasicBlock *LoopPreheaderBB = CurLoop->getLoopPreheader();
2374  assert(LoopPreheaderBB && "There is always a loop preheader.");
2375
2376  BasicBlock *SuccessorBB = CurLoop->getExitBlock();
2377  assert(SuccessorBB && "There is only a single successor.");
2378
2379  IRBuilder<> Builder(LoopPreheaderBB->getTerminator());
2380  Builder.SetCurrentDebugLocation(cast<Instruction>(XCurr)->getDebugLoc());
2381
2382  Intrinsic::ID IntrID = Intrinsic::ctlz;
2383  Type *Ty = X->getType();
2384  unsigned Bitwidth = Ty->getScalarSizeInBits();
2385
2386  TargetTransformInfo::TargetCostKind CostKind =
2387      TargetTransformInfo::TCK_SizeAndLatency;
2388
2389  // The rewrite is considered to be unprofitable iff and only iff the
2390  // intrinsic/shift we'll use are not cheap. Note that we are okay with *just*
2391  // making the loop countable, even if nothing else changes.
2392  IntrinsicCostAttributes Attrs(
2393      IntrID, Ty, {PoisonValue::get(Ty), /*is_zero_poison=*/Builder.getTrue()});
2394  InstructionCost Cost = TTI->getIntrinsicInstrCost(Attrs, CostKind);
2395  if (Cost > TargetTransformInfo::TCC_Basic) {
2396    LLVM_DEBUG(dbgs() << DEBUG_TYPE
2397               " Intrinsic is too costly, not beneficial\n");
2398    return MadeChange;
2399  }
2400  if (TTI->getArithmeticInstrCost(Instruction::Shl, Ty, CostKind) >
2401      TargetTransformInfo::TCC_Basic) {
2402    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Shift is too costly, not beneficial\n");
2403    return MadeChange;
2404  }
2405
2406  // Ok, transform appears worthwhile.
2407  MadeChange = true;
2408
2409  if (!isGuaranteedNotToBeUndefOrPoison(BitPos)) {
2410    // BitMask may be computed from BitPos, Freeze BitPos so we can increase
2411    // it's use count.
2412    Instruction *InsertPt = nullptr;
2413    if (auto *BitPosI = dyn_cast<Instruction>(BitPos))
2414      InsertPt = &**BitPosI->getInsertionPointAfterDef();
2415    else
2416      InsertPt = &*DT->getRoot()->getFirstNonPHIOrDbgOrAlloca();
2417    if (!InsertPt)
2418      return false;
2419    FreezeInst *BitPosFrozen =
2420        new FreezeInst(BitPos, BitPos->getName() + ".fr", InsertPt);
2421    BitPos->replaceUsesWithIf(BitPosFrozen, [BitPosFrozen](Use &U) {
2422      return U.getUser() != BitPosFrozen;
2423    });
2424    BitPos = BitPosFrozen;
2425  }
2426
2427  // Step 1: Compute the loop trip count.
2428
2429  Value *LowBitMask = Builder.CreateAdd(BitMask, Constant::getAllOnesValue(Ty),
2430                                        BitPos->getName() + ".lowbitmask");
2431  Value *Mask =
2432      Builder.CreateOr(LowBitMask, BitMask, BitPos->getName() + ".mask");
2433  Value *XMasked = Builder.CreateAnd(X, Mask, X->getName() + ".masked");
2434  CallInst *XMaskedNumLeadingZeros = Builder.CreateIntrinsic(
2435      IntrID, Ty, {XMasked, /*is_zero_poison=*/Builder.getTrue()},
2436      /*FMFSource=*/nullptr, XMasked->getName() + ".numleadingzeros");
2437  Value *XMaskedNumActiveBits = Builder.CreateSub(
2438      ConstantInt::get(Ty, Ty->getScalarSizeInBits()), XMaskedNumLeadingZeros,
2439      XMasked->getName() + ".numactivebits", /*HasNUW=*/true,
2440      /*HasNSW=*/Bitwidth != 2);
2441  Value *XMaskedLeadingOnePos =
2442      Builder.CreateAdd(XMaskedNumActiveBits, Constant::getAllOnesValue(Ty),
2443                        XMasked->getName() + ".leadingonepos", /*HasNUW=*/false,
2444                        /*HasNSW=*/Bitwidth > 2);
2445
2446  Value *LoopBackedgeTakenCount = Builder.CreateSub(
2447      BitPos, XMaskedLeadingOnePos, CurLoop->getName() + ".backedgetakencount",
2448      /*HasNUW=*/true, /*HasNSW=*/true);
2449  // We know loop's backedge-taken count, but what's loop's trip count?
2450  // Note that while NUW is always safe, while NSW is only for bitwidths != 2.
2451  Value *LoopTripCount =
2452      Builder.CreateAdd(LoopBackedgeTakenCount, ConstantInt::get(Ty, 1),
2453                        CurLoop->getName() + ".tripcount", /*HasNUW=*/true,
2454                        /*HasNSW=*/Bitwidth != 2);
2455
2456  // Step 2: Compute the recurrence's final value without a loop.
2457
2458  // NewX is always safe to compute, because `LoopBackedgeTakenCount`
2459  // will always be smaller than `bitwidth(X)`, i.e. we never get poison.
2460  Value *NewX = Builder.CreateShl(X, LoopBackedgeTakenCount);
2461  NewX->takeName(XCurr);
2462  if (auto *I = dyn_cast<Instruction>(NewX))
2463    I->copyIRFlags(XNext, /*IncludeWrapFlags=*/true);
2464
2465  Value *NewXNext;
2466  // Rewriting XNext is more complicated, however, because `X << LoopTripCount`
2467  // will be poison iff `LoopTripCount == bitwidth(X)` (which will happen
2468  // iff `BitPos` is `bitwidth(x) - 1` and `X` is `1`). So unless we know
2469  // that isn't the case, we'll need to emit an alternative, safe IR.
2470  if (XNext->hasNoSignedWrap() || XNext->hasNoUnsignedWrap() ||
2471      PatternMatch::match(
2472          BitPos, PatternMatch::m_SpecificInt_ICMP(
2473                      ICmpInst::ICMP_NE, APInt(Ty->getScalarSizeInBits(),
2474                                               Ty->getScalarSizeInBits() - 1))))
2475    NewXNext = Builder.CreateShl(X, LoopTripCount);
2476  else {
2477    // Otherwise, just additionally shift by one. It's the smallest solution,
2478    // alternatively, we could check that NewX is INT_MIN (or BitPos is )
2479    // and select 0 instead.
2480    NewXNext = Builder.CreateShl(NewX, ConstantInt::get(Ty, 1));
2481  }
2482
2483  NewXNext->takeName(XNext);
2484  if (auto *I = dyn_cast<Instruction>(NewXNext))
2485    I->copyIRFlags(XNext, /*IncludeWrapFlags=*/true);
2486
2487  // Step 3: Adjust the successor basic block to recieve the computed
2488  //         recurrence's final value instead of the recurrence itself.
2489
2490  XCurr->replaceUsesOutsideBlock(NewX, LoopHeaderBB);
2491  XNext->replaceUsesOutsideBlock(NewXNext, LoopHeaderBB);
2492
2493  // Step 4: Rewrite the loop into a countable form, with canonical IV.
2494
2495  // The new canonical induction variable.
2496  Builder.SetInsertPoint(LoopHeaderBB, LoopHeaderBB->begin());
2497  auto *IV = Builder.CreatePHI(Ty, 2, CurLoop->getName() + ".iv");
2498
2499  // The induction itself.
2500  // Note that while NUW is always safe, while NSW is only for bitwidths != 2.
2501  Builder.SetInsertPoint(LoopHeaderBB->getTerminator());
2502  auto *IVNext =
2503      Builder.CreateAdd(IV, ConstantInt::get(Ty, 1), IV->getName() + ".next",
2504                        /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2);
2505
2506  // The loop trip count check.
2507  auto *IVCheck = Builder.CreateICmpEQ(IVNext, LoopTripCount,
2508                                       CurLoop->getName() + ".ivcheck");
2509  Builder.CreateCondBr(IVCheck, SuccessorBB, LoopHeaderBB);
2510  LoopHeaderBB->getTerminator()->eraseFromParent();
2511
2512  // Populate the IV PHI.
2513  IV->addIncoming(ConstantInt::get(Ty, 0), LoopPreheaderBB);
2514  IV->addIncoming(IVNext, LoopHeaderBB);
2515
2516  // Step 5: Forget the "non-computable" trip-count SCEV associated with the
2517  //   loop. The loop would otherwise not be deleted even if it becomes empty.
2518
2519  SE->forgetLoop(CurLoop);
2520
2521  // Other passes will take care of actually deleting the loop if possible.
2522
2523  LLVM_DEBUG(dbgs() << DEBUG_TYPE " shift-until-bittest idiom optimized!\n");
2524
2525  ++NumShiftUntilBitTest;
2526  return MadeChange;
2527}
2528
2529/// Return true if the idiom is detected in the loop.
2530///
2531/// The core idiom we are trying to detect is:
2532/// \code
2533///   entry:
2534///     <...>
2535///     %start = <...>
2536///     %extraoffset = <...>
2537///     <...>
2538///     br label %for.cond
2539///
2540///   loop:
2541///     %iv = phi i8 [ %start, %entry ], [ %iv.next, %for.cond ]
2542///     %nbits = add nsw i8 %iv, %extraoffset
2543///     %val.shifted = {{l,a}shr,shl} i8 %val, %nbits
2544///     %val.shifted.iszero = icmp eq i8 %val.shifted, 0
2545///     %iv.next = add i8 %iv, 1
2546///     <...>
2547///     br i1 %val.shifted.iszero, label %end, label %loop
2548///
2549///   end:
2550///     %iv.res = phi i8 [ %iv, %loop ] <...>
2551///     %nbits.res = phi i8 [ %nbits, %loop ] <...>
2552///     %val.shifted.res = phi i8 [ %val.shifted, %loop ] <...>
2553///     %val.shifted.iszero.res = phi i1 [ %val.shifted.iszero, %loop ] <...>
2554///     %iv.next.res = phi i8 [ %iv.next, %loop ] <...>
2555///     <...>
2556/// \endcode
2557static bool detectShiftUntilZeroIdiom(Loop *CurLoop, ScalarEvolution *SE,
2558                                      Instruction *&ValShiftedIsZero,
2559                                      Intrinsic::ID &IntrinID, Instruction *&IV,
2560                                      Value *&Start, Value *&Val,
2561                                      const SCEV *&ExtraOffsetExpr,
2562                                      bool &InvertedCond) {
2563  LLVM_DEBUG(dbgs() << DEBUG_TYPE
2564             " Performing shift-until-zero idiom detection.\n");
2565
2566  // Give up if the loop has multiple blocks or multiple backedges.
2567  if (CurLoop->getNumBlocks() != 1 || CurLoop->getNumBackEdges() != 1) {
2568    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad block/backedge count.\n");
2569    return false;
2570  }
2571
2572  Instruction *ValShifted, *NBits, *IVNext;
2573  Value *ExtraOffset;
2574
2575  BasicBlock *LoopHeaderBB = CurLoop->getHeader();
2576  BasicBlock *LoopPreheaderBB = CurLoop->getLoopPreheader();
2577  assert(LoopPreheaderBB && "There is always a loop preheader.");
2578
2579  using namespace PatternMatch;
2580
2581  // Step 1: Check if the loop backedge, condition is in desirable form.
2582
2583  ICmpInst::Predicate Pred;
2584  BasicBlock *TrueBB, *FalseBB;
2585  if (!match(LoopHeaderBB->getTerminator(),
2586             m_Br(m_Instruction(ValShiftedIsZero), m_BasicBlock(TrueBB),
2587                  m_BasicBlock(FalseBB))) ||
2588      !match(ValShiftedIsZero,
2589             m_ICmp(Pred, m_Instruction(ValShifted), m_Zero())) ||
2590      !ICmpInst::isEquality(Pred)) {
2591    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge structure.\n");
2592    return false;
2593  }
2594
2595  // Step 2: Check if the comparison's operand is in desirable form.
2596  // FIXME: Val could be a one-input PHI node, which we should look past.
2597  if (!match(ValShifted, m_Shift(m_LoopInvariant(m_Value(Val), CurLoop),
2598                                 m_Instruction(NBits)))) {
2599    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad comparisons value computation.\n");
2600    return false;
2601  }
2602  IntrinID = ValShifted->getOpcode() == Instruction::Shl ? Intrinsic::cttz
2603                                                         : Intrinsic::ctlz;
2604
2605  // Step 3: Check if the shift amount is in desirable form.
2606
2607  if (match(NBits, m_c_Add(m_Instruction(IV),
2608                           m_LoopInvariant(m_Value(ExtraOffset), CurLoop))) &&
2609      (NBits->hasNoSignedWrap() || NBits->hasNoUnsignedWrap()))
2610    ExtraOffsetExpr = SE->getNegativeSCEV(SE->getSCEV(ExtraOffset));
2611  else if (match(NBits,
2612                 m_Sub(m_Instruction(IV),
2613                       m_LoopInvariant(m_Value(ExtraOffset), CurLoop))) &&
2614           NBits->hasNoSignedWrap())
2615    ExtraOffsetExpr = SE->getSCEV(ExtraOffset);
2616  else {
2617    IV = NBits;
2618    ExtraOffsetExpr = SE->getZero(NBits->getType());
2619  }
2620
2621  // Step 4: Check if the recurrence is in desirable form.
2622  auto *IVPN = dyn_cast<PHINode>(IV);
2623  if (!IVPN || IVPN->getParent() != LoopHeaderBB) {
2624    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Not an expected PHI node.\n");
2625    return false;
2626  }
2627
2628  Start = IVPN->getIncomingValueForBlock(LoopPreheaderBB);
2629  IVNext = dyn_cast<Instruction>(IVPN->getIncomingValueForBlock(LoopHeaderBB));
2630
2631  if (!IVNext || !match(IVNext, m_Add(m_Specific(IVPN), m_One()))) {
2632    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad recurrence.\n");
2633    return false;
2634  }
2635
2636  // Step 4: Check if the backedge's destinations are in desirable form.
2637
2638  assert(ICmpInst::isEquality(Pred) &&
2639         "Should only get equality predicates here.");
2640
2641  // cmp-br is commutative, so canonicalize to a single variant.
2642  InvertedCond = Pred != ICmpInst::Predicate::ICMP_EQ;
2643  if (InvertedCond) {
2644    Pred = ICmpInst::getInversePredicate(Pred);
2645    std::swap(TrueBB, FalseBB);
2646  }
2647
2648  // We expect to exit loop when comparison yields true,
2649  // so when it yields false we should branch back to loop header.
2650  if (FalseBB != LoopHeaderBB) {
2651    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge flow.\n");
2652    return false;
2653  }
2654
2655  // The new, countable, loop will certainly only run a known number of
2656  // iterations, It won't be infinite. But the old loop might be infinite
2657  // under certain conditions. For logical shifts, the value will become zero
2658  // after at most bitwidth(%Val) loop iterations. However, for arithmetic
2659  // right-shift, iff the sign bit was set, the value will never become zero,
2660  // and the loop may never finish.
2661  if (ValShifted->getOpcode() == Instruction::AShr &&
2662      !isMustProgress(CurLoop) && !SE->isKnownNonNegative(SE->getSCEV(Val))) {
2663    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Can not prove the loop is finite.\n");
2664    return false;
2665  }
2666
2667  // Okay, idiom checks out.
2668  return true;
2669}
2670
2671/// Look for the following loop:
2672/// \code
2673///   entry:
2674///     <...>
2675///     %start = <...>
2676///     %extraoffset = <...>
2677///     <...>
2678///     br label %for.cond
2679///
2680///   loop:
2681///     %iv = phi i8 [ %start, %entry ], [ %iv.next, %for.cond ]
2682///     %nbits = add nsw i8 %iv, %extraoffset
2683///     %val.shifted = {{l,a}shr,shl} i8 %val, %nbits
2684///     %val.shifted.iszero = icmp eq i8 %val.shifted, 0
2685///     %iv.next = add i8 %iv, 1
2686///     <...>
2687///     br i1 %val.shifted.iszero, label %end, label %loop
2688///
2689///   end:
2690///     %iv.res = phi i8 [ %iv, %loop ] <...>
2691///     %nbits.res = phi i8 [ %nbits, %loop ] <...>
2692///     %val.shifted.res = phi i8 [ %val.shifted, %loop ] <...>
2693///     %val.shifted.iszero.res = phi i1 [ %val.shifted.iszero, %loop ] <...>
2694///     %iv.next.res = phi i8 [ %iv.next, %loop ] <...>
2695///     <...>
2696/// \endcode
2697///
2698/// And transform it into:
2699/// \code
2700///   entry:
2701///     <...>
2702///     %start = <...>
2703///     %extraoffset = <...>
2704///     <...>
2705///     %val.numleadingzeros = call i8 @llvm.ct{l,t}z.i8(i8 %val, i1 0)
2706///     %val.numactivebits = sub i8 8, %val.numleadingzeros
2707///     %extraoffset.neg = sub i8 0, %extraoffset
2708///     %tmp = add i8 %val.numactivebits, %extraoffset.neg
2709///     %iv.final = call i8 @llvm.smax.i8(i8 %tmp, i8 %start)
2710///     %loop.tripcount = sub i8 %iv.final, %start
2711///     br label %loop
2712///
2713///   loop:
2714///     %loop.iv = phi i8 [ 0, %entry ], [ %loop.iv.next, %loop ]
2715///     %loop.iv.next = add i8 %loop.iv, 1
2716///     %loop.ivcheck = icmp eq i8 %loop.iv.next, %loop.tripcount
2717///     %iv = add i8 %loop.iv, %start
2718///     <...>
2719///     br i1 %loop.ivcheck, label %end, label %loop
2720///
2721///   end:
2722///     %iv.res = phi i8 [ %iv.final, %loop ] <...>
2723///     <...>
2724/// \endcode
2725bool LoopIdiomRecognize::recognizeShiftUntilZero() {
2726  bool MadeChange = false;
2727
2728  Instruction *ValShiftedIsZero;
2729  Intrinsic::ID IntrID;
2730  Instruction *IV;
2731  Value *Start, *Val;
2732  const SCEV *ExtraOffsetExpr;
2733  bool InvertedCond;
2734  if (!detectShiftUntilZeroIdiom(CurLoop, SE, ValShiftedIsZero, IntrID, IV,
2735                                 Start, Val, ExtraOffsetExpr, InvertedCond)) {
2736    LLVM_DEBUG(dbgs() << DEBUG_TYPE
2737               " shift-until-zero idiom detection failed.\n");
2738    return MadeChange;
2739  }
2740  LLVM_DEBUG(dbgs() << DEBUG_TYPE " shift-until-zero idiom detected!\n");
2741
2742  // Ok, it is the idiom we were looking for, we *could* transform this loop,
2743  // but is it profitable to transform?
2744
2745  BasicBlock *LoopHeaderBB = CurLoop->getHeader();
2746  BasicBlock *LoopPreheaderBB = CurLoop->getLoopPreheader();
2747  assert(LoopPreheaderBB && "There is always a loop preheader.");
2748
2749  BasicBlock *SuccessorBB = CurLoop->getExitBlock();
2750  assert(SuccessorBB && "There is only a single successor.");
2751
2752  IRBuilder<> Builder(LoopPreheaderBB->getTerminator());
2753  Builder.SetCurrentDebugLocation(IV->getDebugLoc());
2754
2755  Type *Ty = Val->getType();
2756  unsigned Bitwidth = Ty->getScalarSizeInBits();
2757
2758  TargetTransformInfo::TargetCostKind CostKind =
2759      TargetTransformInfo::TCK_SizeAndLatency;
2760
2761  // The rewrite is considered to be unprofitable iff and only iff the
2762  // intrinsic we'll use are not cheap. Note that we are okay with *just*
2763  // making the loop countable, even if nothing else changes.
2764  IntrinsicCostAttributes Attrs(
2765      IntrID, Ty, {PoisonValue::get(Ty), /*is_zero_poison=*/Builder.getFalse()});
2766  InstructionCost Cost = TTI->getIntrinsicInstrCost(Attrs, CostKind);
2767  if (Cost > TargetTransformInfo::TCC_Basic) {
2768    LLVM_DEBUG(dbgs() << DEBUG_TYPE
2769               " Intrinsic is too costly, not beneficial\n");
2770    return MadeChange;
2771  }
2772
2773  // Ok, transform appears worthwhile.
2774  MadeChange = true;
2775
2776  bool OffsetIsZero = false;
2777  if (auto *ExtraOffsetExprC = dyn_cast<SCEVConstant>(ExtraOffsetExpr))
2778    OffsetIsZero = ExtraOffsetExprC->isZero();
2779
2780  // Step 1: Compute the loop's final IV value / trip count.
2781
2782  CallInst *ValNumLeadingZeros = Builder.CreateIntrinsic(
2783      IntrID, Ty, {Val, /*is_zero_poison=*/Builder.getFalse()},
2784      /*FMFSource=*/nullptr, Val->getName() + ".numleadingzeros");
2785  Value *ValNumActiveBits = Builder.CreateSub(
2786      ConstantInt::get(Ty, Ty->getScalarSizeInBits()), ValNumLeadingZeros,
2787      Val->getName() + ".numactivebits", /*HasNUW=*/true,
2788      /*HasNSW=*/Bitwidth != 2);
2789
2790  SCEVExpander Expander(*SE, *DL, "loop-idiom");
2791  Expander.setInsertPoint(&*Builder.GetInsertPoint());
2792  Value *ExtraOffset = Expander.expandCodeFor(ExtraOffsetExpr);
2793
2794  Value *ValNumActiveBitsOffset = Builder.CreateAdd(
2795      ValNumActiveBits, ExtraOffset, ValNumActiveBits->getName() + ".offset",
2796      /*HasNUW=*/OffsetIsZero, /*HasNSW=*/true);
2797  Value *IVFinal = Builder.CreateIntrinsic(Intrinsic::smax, {Ty},
2798                                           {ValNumActiveBitsOffset, Start},
2799                                           /*FMFSource=*/nullptr, "iv.final");
2800
2801  auto *LoopBackedgeTakenCount = cast<Instruction>(Builder.CreateSub(
2802      IVFinal, Start, CurLoop->getName() + ".backedgetakencount",
2803      /*HasNUW=*/OffsetIsZero, /*HasNSW=*/true));
2804  // FIXME: or when the offset was `add nuw`
2805
2806  // We know loop's backedge-taken count, but what's loop's trip count?
2807  Value *LoopTripCount =
2808      Builder.CreateAdd(LoopBackedgeTakenCount, ConstantInt::get(Ty, 1),
2809                        CurLoop->getName() + ".tripcount", /*HasNUW=*/true,
2810                        /*HasNSW=*/Bitwidth != 2);
2811
2812  // Step 2: Adjust the successor basic block to recieve the original
2813  //         induction variable's final value instead of the orig. IV itself.
2814
2815  IV->replaceUsesOutsideBlock(IVFinal, LoopHeaderBB);
2816
2817  // Step 3: Rewrite the loop into a countable form, with canonical IV.
2818
2819  // The new canonical induction variable.
2820  Builder.SetInsertPoint(LoopHeaderBB, LoopHeaderBB->begin());
2821  auto *CIV = Builder.CreatePHI(Ty, 2, CurLoop->getName() + ".iv");
2822
2823  // The induction itself.
2824  Builder.SetInsertPoint(LoopHeaderBB, LoopHeaderBB->getFirstNonPHIIt());
2825  auto *CIVNext =
2826      Builder.CreateAdd(CIV, ConstantInt::get(Ty, 1), CIV->getName() + ".next",
2827                        /*HasNUW=*/true, /*HasNSW=*/Bitwidth != 2);
2828
2829  // The loop trip count check.
2830  auto *CIVCheck = Builder.CreateICmpEQ(CIVNext, LoopTripCount,
2831                                        CurLoop->getName() + ".ivcheck");
2832  auto *NewIVCheck = CIVCheck;
2833  if (InvertedCond) {
2834    NewIVCheck = Builder.CreateNot(CIVCheck);
2835    NewIVCheck->takeName(ValShiftedIsZero);
2836  }
2837
2838  // The original IV, but rebased to be an offset to the CIV.
2839  auto *IVDePHId = Builder.CreateAdd(CIV, Start, "", /*HasNUW=*/false,
2840                                     /*HasNSW=*/true); // FIXME: what about NUW?
2841  IVDePHId->takeName(IV);
2842
2843  // The loop terminator.
2844  Builder.SetInsertPoint(LoopHeaderBB->getTerminator());
2845  Builder.CreateCondBr(CIVCheck, SuccessorBB, LoopHeaderBB);
2846  LoopHeaderBB->getTerminator()->eraseFromParent();
2847
2848  // Populate the IV PHI.
2849  CIV->addIncoming(ConstantInt::get(Ty, 0), LoopPreheaderBB);
2850  CIV->addIncoming(CIVNext, LoopHeaderBB);
2851
2852  // Step 4: Forget the "non-computable" trip-count SCEV associated with the
2853  //   loop. The loop would otherwise not be deleted even if it becomes empty.
2854
2855  SE->forgetLoop(CurLoop);
2856
2857  // Step 5: Try to cleanup the loop's body somewhat.
2858  IV->replaceAllUsesWith(IVDePHId);
2859  IV->eraseFromParent();
2860
2861  ValShiftedIsZero->replaceAllUsesWith(NewIVCheck);
2862  ValShiftedIsZero->eraseFromParent();
2863
2864  // Other passes will take care of actually deleting the loop if possible.
2865
2866  LLVM_DEBUG(dbgs() << DEBUG_TYPE " shift-until-zero idiom optimized!\n");
2867
2868  ++NumShiftUntilZero;
2869  return MadeChange;
2870}
2871