MVEGatherScatterLowering.cpp revision 360784
1//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// This pass custom lowers llvm.gather and llvm.scatter instructions to
10/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
11/// produce a better final result as we go.
12//
13//===----------------------------------------------------------------------===//
14
15#include "ARM.h"
16#include "ARMBaseInstrInfo.h"
17#include "ARMSubtarget.h"
18#include "llvm/Analysis/TargetTransformInfo.h"
19#include "llvm/CodeGen/TargetLowering.h"
20#include "llvm/CodeGen/TargetPassConfig.h"
21#include "llvm/CodeGen/TargetSubtargetInfo.h"
22#include "llvm/InitializePasses.h"
23#include "llvm/IR/BasicBlock.h"
24#include "llvm/IR/Constant.h"
25#include "llvm/IR/Constants.h"
26#include "llvm/IR/DerivedTypes.h"
27#include "llvm/IR/Function.h"
28#include "llvm/IR/InstrTypes.h"
29#include "llvm/IR/Instruction.h"
30#include "llvm/IR/Instructions.h"
31#include "llvm/IR/IntrinsicInst.h"
32#include "llvm/IR/Intrinsics.h"
33#include "llvm/IR/IntrinsicsARM.h"
34#include "llvm/IR/IRBuilder.h"
35#include "llvm/IR/PatternMatch.h"
36#include "llvm/IR/Type.h"
37#include "llvm/IR/Value.h"
38#include "llvm/Pass.h"
39#include "llvm/Support/Casting.h"
40#include <algorithm>
41#include <cassert>
42
43using namespace llvm;
44
45#define DEBUG_TYPE "mve-gather-scatter-lowering"
46
47cl::opt<bool> EnableMaskedGatherScatters(
48    "enable-arm-maskedgatscat", cl::Hidden, cl::init(false),
49    cl::desc("Enable the generation of masked gathers and scatters"));
50
51namespace {
52
53class MVEGatherScatterLowering : public FunctionPass {
54public:
55  static char ID; // Pass identification, replacement for typeid
56
57  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
58    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
59  }
60
61  bool runOnFunction(Function &F) override;
62
63  StringRef getPassName() const override {
64    return "MVE gather/scatter lowering";
65  }
66
67  void getAnalysisUsage(AnalysisUsage &AU) const override {
68    AU.setPreservesCFG();
69    AU.addRequired<TargetPassConfig>();
70    FunctionPass::getAnalysisUsage(AU);
71  }
72
73private:
74  // Check this is a valid gather with correct alignment
75  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
76                               unsigned Alignment);
77  // Check whether Ptr is hidden behind a bitcast and look through it
78  void lookThroughBitcast(Value *&Ptr);
79  // Check for a getelementptr and deduce base and offsets from it, on success
80  // returning the base directly and the offsets indirectly using the Offsets
81  // argument
82  Value *checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, IRBuilder<> Builder);
83
84  bool lowerGather(IntrinsicInst *I);
85  // Create a gather from a base + vector of offsets
86  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
87                                     IRBuilder<> Builder);
88  // Create a gather from a vector of pointers
89  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
90                                   IRBuilder<> Builder);
91};
92
93} // end anonymous namespace
94
95char MVEGatherScatterLowering::ID = 0;
96
97INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
98                "MVE gather/scattering lowering pass", false, false)
99
100Pass *llvm::createMVEGatherScatterLoweringPass() {
101  return new MVEGatherScatterLowering();
102}
103
104bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
105                                                       unsigned ElemSize,
106                                                       unsigned Alignment) {
107  // Do only allow non-extending gathers for now
108  if (((NumElements == 4 && ElemSize == 32) ||
109       (NumElements == 8 && ElemSize == 16) ||
110       (NumElements == 16 && ElemSize == 8)) &&
111      ElemSize / 8 <= Alignment)
112    return true;
113  LLVM_DEBUG(dbgs() << "masked gathers: instruction does not have valid "
114                    << "alignment or vector type \n");
115  return false;
116}
117
118Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, Value *Ptr,
119                                          IRBuilder<> Builder) {
120  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
121  if (!GEP) {
122    LLVM_DEBUG(dbgs() << "masked gathers: no getelementpointer found\n");
123    return nullptr;
124  }
125  LLVM_DEBUG(dbgs() << "masked gathers: getelementpointer found. Loading"
126                    << " from base + vector of offsets\n");
127  Value *GEPPtr = GEP->getPointerOperand();
128  if (GEPPtr->getType()->isVectorTy()) {
129    LLVM_DEBUG(dbgs() << "masked gathers: gather from a vector of pointers"
130                      << " hidden behind a getelementptr currently not"
131                      << " supported. Expanding.\n");
132    return nullptr;
133  }
134  if (GEP->getNumOperands() != 2) {
135    LLVM_DEBUG(dbgs() << "masked gathers: getelementptr with too many"
136                      << " operands. Expanding.\n");
137    return nullptr;
138  }
139  Offsets = GEP->getOperand(1);
140  // SExt offsets inside masked gathers are not permitted by the architecture;
141  // we therefore can't fold them
142  if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets))
143    Offsets = ZextOffs->getOperand(0);
144  Type *OffsType = VectorType::getInteger(cast<VectorType>(Ty));
145  // If the offset we found does not have the type the intrinsic expects,
146  // i.e., the same type as the gather itself, we need to convert it (only i
147  // types) or fall back to expanding the gather
148  if (OffsType != Offsets->getType()) {
149    if (OffsType->getScalarSizeInBits() >
150        Offsets->getType()->getScalarSizeInBits()) {
151      LLVM_DEBUG(dbgs() << "masked gathers: extending offsets\n");
152      Offsets = Builder.CreateZExt(Offsets, OffsType, "");
153    } else {
154      LLVM_DEBUG(dbgs() << "masked gathers: no correct offset type. Can't"
155                        << " create masked gather\n");
156      return nullptr;
157    }
158  }
159  // If none of the checks failed, return the gep's base pointer
160  return GEPPtr;
161}
162
163void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
164  // Look through bitcast instruction if #elements is the same
165  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
166    Type *BCTy = BitCast->getType();
167    Type *BCSrcTy = BitCast->getOperand(0)->getType();
168    if (BCTy->getVectorNumElements() == BCSrcTy->getVectorNumElements()) {
169      LLVM_DEBUG(dbgs() << "masked gathers: looking through bitcast\n");
170      Ptr = BitCast->getOperand(0);
171    }
172  }
173}
174
175bool MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
176  using namespace PatternMatch;
177  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");
178
179  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
180  // Attempt to turn the masked gather in I into a MVE intrinsic
181  // Potentially optimising the addressing modes as we do so.
182  Type *Ty = I->getType();
183  Value *Ptr = I->getArgOperand(0);
184  unsigned Alignment = cast<ConstantInt>(I->getArgOperand(1))->getZExtValue();
185  Value *Mask = I->getArgOperand(2);
186  Value *PassThru = I->getArgOperand(3);
187
188  if (!isLegalTypeAndAlignment(Ty->getVectorNumElements(),
189                               Ty->getScalarSizeInBits(), Alignment))
190    return false;
191  lookThroughBitcast(Ptr);
192  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
193
194  IRBuilder<> Builder(I->getContext());
195  Builder.SetInsertPoint(I);
196  Builder.SetCurrentDebugLocation(I->getDebugLoc());
197  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Builder);
198  if (!Load)
199    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
200  if (!Load)
201    return false;
202
203  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
204    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
205                      << "creating select\n");
206    Load = Builder.CreateSelect(Mask, Load, PassThru);
207  }
208
209  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
210  I->replaceAllUsesWith(Load);
211  I->eraseFromParent();
212  return true;
213}
214
215Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
216    IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) {
217  using namespace PatternMatch;
218  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
219  Type *Ty = I->getType();
220  if (Ty->getVectorNumElements() != 4)
221    // Can't build an intrinsic for this
222    return nullptr;
223  Value *Mask = I->getArgOperand(2);
224  if (match(Mask, m_One()))
225    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
226                                   {Ty, Ptr->getType()},
227                                   {Ptr, Builder.getInt32(0)});
228  else
229    return Builder.CreateIntrinsic(
230        Intrinsic::arm_mve_vldr_gather_base_predicated,
231        {Ty, Ptr->getType(), Mask->getType()},
232        {Ptr, Builder.getInt32(0), Mask});
233}
234
235Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
236    IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) {
237  using namespace PatternMatch;
238  Type *Ty = I->getType();
239  Value *Offsets;
240  Value *BasePtr = checkGEP(Offsets, Ty, Ptr, Builder);
241  if (!BasePtr)
242    return nullptr;
243
244  unsigned Scale;
245  int GEPElemSize =
246      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits();
247  int ResultElemSize = Ty->getScalarSizeInBits();
248  // This can be a 32bit load scaled by 4, a 16bit load scaled by 2, or a
249  // 8bit, 16bit or 32bit load scaled by 1
250  if (GEPElemSize == 32 && ResultElemSize == 32) {
251    Scale = 2;
252  } else if (GEPElemSize == 16 && ResultElemSize == 16) {
253    Scale = 1;
254  } else if (GEPElemSize == 8) {
255    Scale = 0;
256  } else {
257    LLVM_DEBUG(dbgs() << "masked gathers: incorrect scale for load. Can't"
258                      << " create masked gather\n");
259    return nullptr;
260  }
261
262  Value *Mask = I->getArgOperand(2);
263  if (!match(Mask, m_One()))
264    return Builder.CreateIntrinsic(
265        Intrinsic::arm_mve_vldr_gather_offset_predicated,
266        {Ty, BasePtr->getType(), Offsets->getType(), Mask->getType()},
267        {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()),
268         Builder.getInt32(Scale), Builder.getInt32(1), Mask});
269  else
270    return Builder.CreateIntrinsic(
271        Intrinsic::arm_mve_vldr_gather_offset,
272        {Ty, BasePtr->getType(), Offsets->getType()},
273        {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()),
274         Builder.getInt32(Scale), Builder.getInt32(1)});
275}
276
277bool MVEGatherScatterLowering::runOnFunction(Function &F) {
278  if (!EnableMaskedGatherScatters)
279    return false;
280  auto &TPC = getAnalysis<TargetPassConfig>();
281  auto &TM = TPC.getTM<TargetMachine>();
282  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
283  if (!ST->hasMVEIntegerOps())
284    return false;
285  SmallVector<IntrinsicInst *, 4> Gathers;
286  for (BasicBlock &BB : F) {
287    for (Instruction &I : BB) {
288      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
289      if (II && II->getIntrinsicID() == Intrinsic::masked_gather)
290        Gathers.push_back(II);
291    }
292  }
293
294  if (Gathers.empty())
295    return false;
296
297  for (IntrinsicInst *I : Gathers)
298    lowerGather(I);
299
300  return true;
301}
302