1//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Identification:
10// This step is responsible for finding the patterns that can be lowered to
11// complex instructions, and building a graph to represent the complex
12// structures. Starting from the "Converging Shuffle" (a shuffle that
13// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14// operands are evaluated and identified as "Composite Nodes" (collections of
15// instructions that can potentially be lowered to a single complex
16// instruction). This is performed by checking the real and imaginary components
17// and tracking the data flow for each component while following the operand
18// pairs. Validity of each node is expected to be done upon creation, and any
19// validation errors should halt traversal and prevent further graph
20// construction.
21// Instead of relying on Shuffle operations, vector interleaving and
22// deinterleaving can be represented by vector.interleave2 and
23// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24// these intrinsics, whereas, fixed-width vectors are recognized for both
25// shufflevector instruction and intrinsics.
26//
27// Replacement:
28// This step traverses the graph built up by identification, delegating to the
29// target to validate and generate the correct intrinsics, and plumbs them
30// together connecting each end of the new intrinsics graph to the existing
31// use-def chain. This step is assumed to finish successfully, as all
32// information is expected to be correct by this point.
33//
34//
35// Internal data structure:
36// ComplexDeinterleavingGraph:
37// Keeps references to all the valid CompositeNodes formed as part of the
38// transformation, and every Instruction contained within said nodes. It also
39// holds onto a reference to the root Instruction, and the root node that should
40// replace it.
41//
42// ComplexDeinterleavingCompositeNode:
43// A CompositeNode represents a single transformation point; each node should
44// transform into a single complex instruction (ignoring vector splitting, which
45// would generate more instructions per node). They are identified in a
46// depth-first manner, traversing and identifying the operands of each
47// instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
49// as well as any additional instructions that make up the identified operation
50// (Internal instructions should only have uses within their containing node).
51// A Node also contains the rotation and operation type that it represents.
52// Operands contains pointers to other CompositeNodes, acting as the edges in
53// the graph. ReplacementValue is the transformed Value* that has been emitted
54// to the IR.
55//
56// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58// should be pre-populated.
59//
60//===----------------------------------------------------------------------===//
61
62#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63#include "llvm/ADT/MapVector.h"
64#include "llvm/ADT/Statistic.h"
65#include "llvm/Analysis/TargetLibraryInfo.h"
66#include "llvm/Analysis/TargetTransformInfo.h"
67#include "llvm/CodeGen/TargetLowering.h"
68#include "llvm/CodeGen/TargetPassConfig.h"
69#include "llvm/CodeGen/TargetSubtargetInfo.h"
70#include "llvm/IR/IRBuilder.h"
71#include "llvm/IR/PatternMatch.h"
72#include "llvm/InitializePasses.h"
73#include "llvm/Target/TargetMachine.h"
74#include "llvm/Transforms/Utils/Local.h"
75#include <algorithm>
76
77using namespace llvm;
78using namespace PatternMatch;
79
80#define DEBUG_TYPE "complex-deinterleaving"
81
/// Statistic counter registered under DEBUG_TYPE; per its description string
/// it counts the complex patterns this pass has transformed.
STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

/// Hidden command-line switch `-enable-complex-deinterleaving`, defaulting to
/// true. When it is false, ComplexDeinterleaving::runOnFunction returns early
/// and reports that nothing changed.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);
88
/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
/// 4x vector interleaving mask would be <0, 2, 1, 3>). A mask with an odd
/// number of elements is never interleaving.
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2, and either start
/// with 0 or 1.
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
/// <1, 3, 5, 7>).
static bool isDeinterleavingMask(ArrayRef<int> Mask);

/// Returns true if \p V is a negation, recognised through either the
/// floating-point (`m_FNeg`) or the integer (`m_Neg`) pattern matcher.
static bool isNeg(Value *V);

/// Returns the operand being negated by \p V. \p V must satisfy isNeg(); this
/// is asserted.
static Value *getNegOperand(Value *V);
110
111namespace {
112
113class ComplexDeinterleavingLegacyPass : public FunctionPass {
114public:
115  static char ID;
116
117  ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118      : FunctionPass(ID), TM(TM) {
119    initializeComplexDeinterleavingLegacyPassPass(
120        *PassRegistry::getPassRegistry());
121  }
122
123  StringRef getPassName() const override {
124    return "Complex Deinterleaving Pass";
125  }
126
127  bool runOnFunction(Function &F) override;
128  void getAnalysisUsage(AnalysisUsage &AU) const override {
129    AU.addRequired<TargetLibraryInfoWrapperPass>();
130    AU.setPreservesCFG();
131  }
132
133private:
134  const TargetMachine *TM;
135};
136
137class ComplexDeinterleavingGraph;
138struct ComplexDeinterleavingCompositeNode {
139
140  ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
141                                     Value *R, Value *I)
142      : Operation(Op), Real(R), Imag(I) {}
143
144private:
145  friend class ComplexDeinterleavingGraph;
146  using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147  using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148
149public:
150  ComplexDeinterleavingOperation Operation;
151  Value *Real;
152  Value *Imag;
153
154  // This two members are required exclusively for generating
155  // ComplexDeinterleavingOperation::Symmetric operations.
156  unsigned Opcode;
157  std::optional<FastMathFlags> Flags;
158
159  ComplexDeinterleavingRotation Rotation =
160      ComplexDeinterleavingRotation::Rotation_0;
161  SmallVector<RawNodePtr> Operands;
162  Value *ReplacementNode = nullptr;
163
164  void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165
166  void dump() { dump(dbgs()); }
167  void dump(raw_ostream &OS) {
168    auto PrintValue = [&](Value *V) {
169      if (V) {
170        OS << "\"";
171        V->print(OS, true);
172        OS << "\"\n";
173      } else
174        OS << "nullptr\n";
175    };
176    auto PrintNodeRef = [&](RawNodePtr Ptr) {
177      if (Ptr)
178        OS << Ptr << "\n";
179      else
180        OS << "nullptr\n";
181    };
182
183    OS << "- CompositeNode: " << this << "\n";
184    OS << "  Real: ";
185    PrintValue(Real);
186    OS << "  Imag: ";
187    PrintValue(Imag);
188    OS << "  ReplacementNode: ";
189    PrintValue(ReplacementNode);
190    OS << "  Operation: " << (int)Operation << "\n";
191    OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
192    OS << "  Operands: \n";
193    for (const auto &Op : Operands) {
194      OS << "    - ";
195      PrintNodeRef(Op);
196    }
197  }
198};
199
/// Graph of composite nodes built while identifying complex operations, plus
/// the entry points used to identify, validate and replace them.
class ComplexDeinterleavingGraph {
public:
  /// One multiplication term made of two factors. IsPositive presumably
  /// records the sign with which the term contributes to a sum — confirm
  /// against identifyReassocNodes / identifyMultiplications.
  struct Product {
    Value *Multiplier;
    Value *Multiplicand;
    bool IsPositive;
  };

  /// A value paired with a flag; presumably the flag is true when the value is
  /// added and false when it is subtracted — confirm against
  /// identifyAdditions.
  using Addend = std::pair<Value *, bool>;
  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;

  // Helper struct for holding info about potential partial multiplication
  // candidates.
  struct PartialMulCandidate {
    Value *Common;
    NodePtr Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI)
      : TL(TL), TLI(TLI) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;

  /// Every node passed to submitCompositeNode, in submission order.
  SmallVector<NodePtr> CompositeNodes;

  /// Maps a (Real, Imag) value pair to the node most recently submitted for
  /// it. Only nodes with both parts non-null are recorded.
  DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts.
  std::map<Instruction *, NodePtr> RootToNode;

  /// Topologically sorted root instructions.
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd i64 ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOP, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
  /// to the %ReductionOPs that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
  /// detection.
  bool PHIsFound = false;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  std::map<PHINode *, PHINode *> OldToNewPHI;

  /// Allocates a new node for \p Operation with real part \p R and imaginary
  /// part \p I. The node is not recorded in the graph until it is passed to
  /// submitCompositeNode. ReductionPHI and ReductionOperation nodes must have
  /// both parts non-null (asserted).
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Value *R, Value *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Records \p Node in CompositeNodes and, when both of its parts are
  /// non-null, caches it under its (Real, Imag) pair. Returns \p Node so the
  /// call can be used directly in a return statement.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    if (Node->Real && Node->Imag)
      CachedResult[{Node->Real, Node->Imag}] = Node;
    return Node;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr
  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
                              std::pair<Value *, Value *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);

  /// Identifies a pair of instructions with the same opcode whose operands
  /// pair up, position by position, into complex nodes.
  NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);

  NodePtr identifyNode(Value *R, Value *I);

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
                            std::list<Addend> &ImagAddends,
                            std::optional<FastMathFlags> Flags,
                            NodePtr Accumulator);

  /// Extract one addend that has both real and imaginary parts positive.
  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
                                std::list<Addend> &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
                                  std::vector<Product> &ImagMuls,
                                  NodePtr Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a pair with a common
  /// operand.
  bool collectPartialMuls(const std::vector<Product> &RealMuls,
                          const std::vector<Product> &ImagMuls,
                          std::vector<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);

  NodePtr identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for \p Real instruction and
  /// odd indices for \p Imag instructions (only for fixed-width vectors)
  /// * Using two extractvalue instructions applied to `vector.deinterleave2`
  /// intrinsic (for both fixed and scalable vectors)
  NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);

  /// Identifies the operation that represents a complex number repeated in a
  /// Splat vector. There are two possible types of splats: ConstantExpr with
  /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
  /// initialization mask with all values set to zero.
  NodePtr identifySplat(Value *Real, Value *Imag);

  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding.
  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  /// ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  /// reduction users
  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);

public:
  /// Dumps every submitted node to the debug stream.
  void dump() { dump(dbgs()); }

  /// Dumps every submitted node to \p OS, in submission order.
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// In case \p B is one-block loop, this function seeks potential reductions
  /// and populates ReductionInfo. Returns true if any reductions were
  /// identified.
  bool collectPotentialReductions(BasicBlock *B);

  void identifyReductionNodes();

  /// Check that every instruction, from the roots to the leaves, has internal
  /// uses.
  bool checkNodes();

  /// Perform the actual replacement of the underlying instruction graph.
  void replaceNodes();
};
426
427class ComplexDeinterleaving {
428public:
429  ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430      : TL(tl), TLI(tli) {}
431  bool runOnFunction(Function &F);
432
433private:
434  bool evaluateBasicBlock(BasicBlock *B);
435
436  const TargetLowering *TL = nullptr;
437  const TargetLibraryInfo *TLI = nullptr;
438};
439
440} // namespace
441
// Out-of-class definition of the legacy pass's static ID member, which the
// constructor passes to FunctionPass.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Legacy pass registration under DEBUG_TYPE ("complex-deinterleaving"). No
// dependency entries are listed between the BEGIN and END macros.
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)
448
449PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
450                                                 FunctionAnalysisManager &AM) {
451  const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
452  auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
453  if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
454    return PreservedAnalyses::all();
455
456  PreservedAnalyses PA;
457  PA.preserve<FunctionAnalysisManagerModuleProxy>();
458  return PA;
459}
460
461FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
462  return new ComplexDeinterleavingLegacyPass(TM);
463}
464
465bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
466  const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
467  auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
468  return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
469}
470
471bool ComplexDeinterleaving::runOnFunction(Function &F) {
472  if (!ComplexDeinterleavingEnabled) {
473    LLVM_DEBUG(
474        dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
475    return false;
476  }
477
478  if (!TL->isComplexDeinterleavingSupported()) {
479    LLVM_DEBUG(
480        dbgs() << "Complex deinterleaving has been disabled, target does "
481                  "not support lowering of complex number operations.\n");
482    return false;
483  }
484
485  bool Changed = false;
486  for (auto &B : F)
487    Changed |= evaluateBasicBlock(&B);
488
489  return Changed;
490}
491
492static bool isInterleavingMask(ArrayRef<int> Mask) {
493  // If the size is not even, it's not an interleaving mask
494  if ((Mask.size() & 1))
495    return false;
496
497  int HalfNumElements = Mask.size() / 2;
498  for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
499    int MaskIdx = Idx * 2;
500    if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
501      return false;
502  }
503
504  return true;
505}
506
507static bool isDeinterleavingMask(ArrayRef<int> Mask) {
508  int Offset = Mask[0];
509  int HalfNumElements = Mask.size() / 2;
510
511  for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
512    if (Mask[Idx] != (Idx * 2) + Offset)
513      return false;
514  }
515
516  return true;
517}
518
519bool isNeg(Value *V) {
520  return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
521}
522
523Value *getNegOperand(Value *V) {
524  assert(isNeg(V));
525  auto *I = cast<Instruction>(V);
526  if (I->getOpcode() == Instruction::FNeg)
527    return I->getOperand(0);
528
529  return I->getOperand(1);
530}
531
532bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
533  ComplexDeinterleavingGraph Graph(TL, TLI);
534  if (Graph.collectPotentialReductions(B))
535    Graph.identifyReductionNodes();
536
537  for (auto &I : *B)
538    Graph.identifyNodes(&I);
539
540  if (Graph.checkNodes()) {
541    Graph.replaceNodes();
542    return true;
543  }
544
545  return false;
546}
547
548ComplexDeinterleavingGraph::NodePtr
549ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550    Instruction *Real, Instruction *Imag,
551    std::pair<Value *, Value *> &PartialMatch) {
552  LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
553                    << "\n");
554
555  if (!Real->hasOneUse() || !Imag->hasOneUse()) {
556    LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
557    return nullptr;
558  }
559
560  if ((Real->getOpcode() != Instruction::FMul &&
561       Real->getOpcode() != Instruction::Mul) ||
562      (Imag->getOpcode() != Instruction::FMul &&
563       Imag->getOpcode() != Instruction::Mul)) {
564    LLVM_DEBUG(
565        dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
566    return nullptr;
567  }
568
569  Value *R0 = Real->getOperand(0);
570  Value *R1 = Real->getOperand(1);
571  Value *I0 = Imag->getOperand(0);
572  Value *I1 = Imag->getOperand(1);
573
574  // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575  // rotations and use the operand.
576  unsigned Negs = 0;
577  Value *Op;
578  if (match(R0, m_Neg(m_Value(Op)))) {
579    Negs |= 1;
580    R0 = Op;
581  } else if (match(R1, m_Neg(m_Value(Op)))) {
582    Negs |= 1;
583    R1 = Op;
584  }
585
586  if (isNeg(I0)) {
587    Negs |= 2;
588    Negs ^= 1;
589    I0 = Op;
590  } else if (match(I1, m_Neg(m_Value(Op)))) {
591    Negs |= 2;
592    Negs ^= 1;
593    I1 = Op;
594  }
595
596  ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
597
598  Value *CommonOperand;
599  Value *UncommonRealOp;
600  Value *UncommonImagOp;
601
602  if (R0 == I0 || R0 == I1) {
603    CommonOperand = R0;
604    UncommonRealOp = R1;
605  } else if (R1 == I0 || R1 == I1) {
606    CommonOperand = R1;
607    UncommonRealOp = R0;
608  } else {
609    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
610    return nullptr;
611  }
612
613  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
614  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
615      Rotation == ComplexDeinterleavingRotation::Rotation_270)
616    std::swap(UncommonRealOp, UncommonImagOp);
617
618  // Between identifyPartialMul and here we need to have found a complete valid
619  // pair from the CommonOperand of each part.
620  if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
621      Rotation == ComplexDeinterleavingRotation::Rotation_180)
622    PartialMatch.first = CommonOperand;
623  else
624    PartialMatch.second = CommonOperand;
625
626  if (!PartialMatch.first || !PartialMatch.second) {
627    LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
628    return nullptr;
629  }
630
631  NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
632  if (!CommonNode) {
633    LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
634    return nullptr;
635  }
636
637  NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
638  if (!UncommonNode) {
639    LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
640    return nullptr;
641  }
642
643  NodePtr Node = prepareCompositeNode(
644      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
645  Node->Rotation = Rotation;
646  Node->addOperand(CommonNode);
647  Node->addOperand(UncommonNode);
648  return submitCompositeNode(Node);
649}
650
651ComplexDeinterleavingGraph::NodePtr
652ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
653                                               Instruction *Imag) {
654  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
655                    << "\n");
656  // Determine rotation
657  auto IsAdd = [](unsigned Op) {
658    return Op == Instruction::FAdd || Op == Instruction::Add;
659  };
660  auto IsSub = [](unsigned Op) {
661    return Op == Instruction::FSub || Op == Instruction::Sub;
662  };
663  ComplexDeinterleavingRotation Rotation;
664  if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
665    Rotation = ComplexDeinterleavingRotation::Rotation_0;
666  else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
667    Rotation = ComplexDeinterleavingRotation::Rotation_90;
668  else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
669    Rotation = ComplexDeinterleavingRotation::Rotation_180;
670  else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
671    Rotation = ComplexDeinterleavingRotation::Rotation_270;
672  else {
673    LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
674    return nullptr;
675  }
676
677  if (isa<FPMathOperator>(Real) &&
678      (!Real->getFastMathFlags().allowContract() ||
679       !Imag->getFastMathFlags().allowContract())) {
680    LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
681    return nullptr;
682  }
683
684  Value *CR = Real->getOperand(0);
685  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
686  if (!RealMulI)
687    return nullptr;
688  Value *CI = Imag->getOperand(0);
689  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
690  if (!ImagMulI)
691    return nullptr;
692
693  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
694    LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
695    return nullptr;
696  }
697
698  Value *R0 = RealMulI->getOperand(0);
699  Value *R1 = RealMulI->getOperand(1);
700  Value *I0 = ImagMulI->getOperand(0);
701  Value *I1 = ImagMulI->getOperand(1);
702
703  Value *CommonOperand;
704  Value *UncommonRealOp;
705  Value *UncommonImagOp;
706
707  if (R0 == I0 || R0 == I1) {
708    CommonOperand = R0;
709    UncommonRealOp = R1;
710  } else if (R1 == I0 || R1 == I1) {
711    CommonOperand = R1;
712    UncommonRealOp = R0;
713  } else {
714    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
715    return nullptr;
716  }
717
718  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
719  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
720      Rotation == ComplexDeinterleavingRotation::Rotation_270)
721    std::swap(UncommonRealOp, UncommonImagOp);
722
723  std::pair<Value *, Value *> PartialMatch(
724      (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
725       Rotation == ComplexDeinterleavingRotation::Rotation_180)
726          ? CommonOperand
727          : nullptr,
728      (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
729       Rotation == ComplexDeinterleavingRotation::Rotation_270)
730          ? CommonOperand
731          : nullptr);
732
733  auto *CRInst = dyn_cast<Instruction>(CR);
734  auto *CIInst = dyn_cast<Instruction>(CI);
735
736  if (!CRInst || !CIInst) {
737    LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
738    return nullptr;
739  }
740
741  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
742  if (!CNode) {
743    LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
744    return nullptr;
745  }
746
747  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
748  if (!UncommonRes) {
749    LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
750    return nullptr;
751  }
752
753  assert(PartialMatch.first && PartialMatch.second);
754  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
755  if (!CommonRes) {
756    LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
757    return nullptr;
758  }
759
760  NodePtr Node = prepareCompositeNode(
761      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
762  Node->Rotation = Rotation;
763  Node->addOperand(CommonRes);
764  Node->addOperand(UncommonRes);
765  Node->addOperand(CNode);
766  return submitCompositeNode(Node);
767}
768
769ComplexDeinterleavingGraph::NodePtr
770ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
771  LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
772
773  // Determine rotation
774  ComplexDeinterleavingRotation Rotation;
775  if ((Real->getOpcode() == Instruction::FSub &&
776       Imag->getOpcode() == Instruction::FAdd) ||
777      (Real->getOpcode() == Instruction::Sub &&
778       Imag->getOpcode() == Instruction::Add))
779    Rotation = ComplexDeinterleavingRotation::Rotation_90;
780  else if ((Real->getOpcode() == Instruction::FAdd &&
781            Imag->getOpcode() == Instruction::FSub) ||
782           (Real->getOpcode() == Instruction::Add &&
783            Imag->getOpcode() == Instruction::Sub))
784    Rotation = ComplexDeinterleavingRotation::Rotation_270;
785  else {
786    LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
787    return nullptr;
788  }
789
790  auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
791  auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
792  auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
793  auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
794
795  if (!AR || !AI || !BR || !BI) {
796    LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
797    return nullptr;
798  }
799
800  NodePtr ResA = identifyNode(AR, AI);
801  if (!ResA) {
802    LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
803    return nullptr;
804  }
805  NodePtr ResB = identifyNode(BR, BI);
806  if (!ResB) {
807    LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
808    return nullptr;
809  }
810
811  NodePtr Node =
812      prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
813  Node->Rotation = Rotation;
814  Node->addOperand(ResA);
815  Node->addOperand(ResB);
816  return submitCompositeNode(Node);
817}
818
819static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
820  unsigned OpcA = A->getOpcode();
821  unsigned OpcB = B->getOpcode();
822
823  return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
824         (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
825         (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
826         (OpcA == Instruction::Add && OpcB == Instruction::Sub);
827}
828
829static bool isInstructionPairMul(Instruction *A, Instruction *B) {
830  auto Pattern =
831      m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
832
833  return match(A, Pattern) && match(B, Pattern);
834}
835
836static bool isInstructionPotentiallySymmetric(Instruction *I) {
837  switch (I->getOpcode()) {
838  case Instruction::FAdd:
839  case Instruction::FSub:
840  case Instruction::FMul:
841  case Instruction::FNeg:
842  case Instruction::Add:
843  case Instruction::Sub:
844  case Instruction::Mul:
845    return true;
846  default:
847    return false;
848  }
849}
850
851ComplexDeinterleavingGraph::NodePtr
852ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
853                                                       Instruction *Imag) {
854  if (Real->getOpcode() != Imag->getOpcode())
855    return nullptr;
856
857  if (!isInstructionPotentiallySymmetric(Real) ||
858      !isInstructionPotentiallySymmetric(Imag))
859    return nullptr;
860
861  auto *R0 = Real->getOperand(0);
862  auto *I0 = Imag->getOperand(0);
863
864  NodePtr Op0 = identifyNode(R0, I0);
865  NodePtr Op1 = nullptr;
866  if (Op0 == nullptr)
867    return nullptr;
868
869  if (Real->isBinaryOp()) {
870    auto *R1 = Real->getOperand(1);
871    auto *I1 = Imag->getOperand(1);
872    Op1 = identifyNode(R1, I1);
873    if (Op1 == nullptr)
874      return nullptr;
875  }
876
877  if (isa<FPMathOperator>(Real) &&
878      Real->getFastMathFlags() != Imag->getFastMathFlags())
879    return nullptr;
880
881  auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
882                                   Real, Imag);
883  Node->Opcode = Real->getOpcode();
884  if (isa<FPMathOperator>(Real))
885    Node->Flags = Real->getFastMathFlags();
886
887  Node->addOperand(Op0);
888  if (Real->isBinaryOp())
889    Node->addOperand(Op1);
890
891  return submitCompositeNode(Node);
892}
893
894ComplexDeinterleavingGraph::NodePtr
895ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
896  LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
897  assert(R->getType() == I->getType() &&
898         "Real and imaginary parts should not have different types");
899
900  auto It = CachedResult.find({R, I});
901  if (It != CachedResult.end()) {
902    LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
903    return It->second;
904  }
905
906  if (NodePtr CN = identifySplat(R, I))
907    return CN;
908
909  auto *Real = dyn_cast<Instruction>(R);
910  auto *Imag = dyn_cast<Instruction>(I);
911  if (!Real || !Imag)
912    return nullptr;
913
914  if (NodePtr CN = identifyDeinterleave(Real, Imag))
915    return CN;
916
917  if (NodePtr CN = identifyPHINode(Real, Imag))
918    return CN;
919
920  if (NodePtr CN = identifySelectNode(Real, Imag))
921    return CN;
922
923  auto *VTy = cast<VectorType>(Real->getType());
924  auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
925
926  bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
927      ComplexDeinterleavingOperation::CMulPartial, NewVTy);
928  bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
929      ComplexDeinterleavingOperation::CAdd, NewVTy);
930
931  if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
932    if (NodePtr CN = identifyPartialMul(Real, Imag))
933      return CN;
934  }
935
936  if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
937    if (NodePtr CN = identifyAdd(Real, Imag))
938      return CN;
939  }
940
941  if (HasCMulSupport && HasCAddSupport) {
942    if (NodePtr CN = identifyReassocNodes(Real, Imag))
943      return CN;
944  }
945
946  if (NodePtr CN = identifySymmetricOperation(Real, Imag))
947    return CN;
948
949  LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
950  CachedResult[{R, I}] = nullptr;
951  return nullptr;
952}
953
/// Tries to identify a complex operation rooted at \p Real / \p Imag by
/// flattening each root's tree of additions, subtractions and negations into
/// a list of signed products and a list of signed addends, and then rebuilding
/// the expression through identifyMultiplications and identifyAdditions.
///
/// Only roots whose opcode is FAdd, FSub, FNeg, Add or Sub are considered. For
/// floating point roots, both must carry identical fast math flags that allow
/// reassociation.
///
/// \returns the final node, with its Real/Imag fields set to the two roots and
/// submitted to the graph, or nullptr if identification fails.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                 Instruction *Imag) {
  auto IsOperationSupported = [](unsigned Opcode) -> bool {
    return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
           Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
           Opcode == Instruction::Sub;
  };

  if (!IsOperationSupported(Real->getOpcode()) ||
      !IsOperationSupported(Imag->getOpcode()))
    return nullptr;

  // Flags stays empty for integer roots; for floating point roots it holds the
  // (common) fast math flags of both root instructions.
  std::optional<FastMathFlags> Flags;
  if (isa<FPMathOperator>(Real)) {
    if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
      LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
                           "not identical\n");
      return nullptr;
    }

    Flags = Real->getFastMathFlags();
    if (!Flags->allowReassoc()) {
      LLVM_DEBUG(
          dbgs()
          << "the 'Reassoc' attribute is missing in the FastMath flags\n");
      return nullptr;
    }
  }

  // Collect multiplications and addend instructions from the given instruction
  // while traversing its operands. Additionally, verify that all instructions
  // have the same fast math flags.
  auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
                          std::list<Addend> &Addends) -> bool {
    // Each worklist entry is a value together with the sign it contributes
    // with (true means positive).
    SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
    SmallPtrSet<Value *, 8> Visited;
    while (!Worklist.empty()) {
      auto [V, IsPositive] = Worklist.back();
      Worklist.pop_back();
      // Every value is handled at most once per traversal.
      if (!Visited.insert(V).second)
        continue;

      // Anything that is not an instruction is a leaf addend.
      Instruction *I = dyn_cast<Instruction>(V);
      if (!I) {
        Addends.emplace_back(V, IsPositive);
        continue;
      }

      // If an instruction has more than one user, it indicates that it either
      // has an external user, which will be later checked by the checkNodes
      // function, or it is a subexpression utilized by multiple expressions. In
      // the latter case, we will attempt to separately identify the complex
      // operation from here in order to create a shared
      // ComplexDeinterleavingCompositeNode.
      if (I != Insn && I->getNumUses() > 1) {
        LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
        Addends.emplace_back(I, IsPositive);
        continue;
      }
      // The worklist is popped from the back, so operand 1 is pushed first and
      // operand 0 is visited first.
      switch (I->getOpcode()) {
      case Instruction::FAdd:
      case Instruction::Add:
        Worklist.emplace_back(I->getOperand(1), IsPositive);
        Worklist.emplace_back(I->getOperand(0), IsPositive);
        break;
      case Instruction::FSub:
        // The subtrahend contributes with the opposite sign.
        Worklist.emplace_back(I->getOperand(1), !IsPositive);
        Worklist.emplace_back(I->getOperand(0), IsPositive);
        break;
      case Instruction::Sub:
        // An integer negation only forwards its operand with a flipped sign;
        // any other Sub is handled like FSub.
        if (isNeg(I)) {
          Worklist.emplace_back(getNegOperand(I), !IsPositive);
        } else {
          Worklist.emplace_back(I->getOperand(1), !IsPositive);
          Worklist.emplace_back(I->getOperand(0), IsPositive);
        }
        break;
      case Instruction::FMul:
      case Instruction::Mul: {
        // Strip negations from the factors and fold them into the sign of the
        // recorded product.
        Value *A, *B;
        if (isNeg(I->getOperand(0))) {
          A = getNegOperand(I->getOperand(0));
          IsPositive = !IsPositive;
        } else {
          A = I->getOperand(0);
        }

        if (isNeg(I->getOperand(1))) {
          B = getNegOperand(I->getOperand(1));
          IsPositive = !IsPositive;
        } else {
          B = I->getOperand(1);
        }
        Muls.push_back(Product{A, B, IsPositive});
        break;
      }
      case Instruction::FNeg:
        Worklist.emplace_back(I->getOperand(0), !IsPositive);
        break;
      default:
        // Any other instruction is a leaf addend. The `continue` also skips
        // the fast math flag check below for such leaves.
        Addends.emplace_back(I, IsPositive);
        continue;
      }

      // Only reached for the instructions handled by the switch above; when
      // Flags is set they must carry exactly the root's fast math flags.
      if (Flags && I->getFastMathFlags() != *Flags) {
        LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
                             "inconsistent with the root instructions' flags: "
                          << *I << "\n");
        return false;
      }
    }
    return true;
  };

  std::vector<Product> RealMuls, ImagMuls;
  std::list<Addend> RealAddends, ImagAddends;
  if (!Collect(Real, RealMuls, RealAddends) ||
      !Collect(Imag, ImagMuls, ImagAddends))
    return nullptr;

  // Addends are consumed in real/imaginary pairs, so the counts must match.
  if (RealAddends.size() != ImagAddends.size())
    return nullptr;

  NodePtr FinalNode;
  if (!RealMuls.empty() || !ImagMuls.empty()) {
    // If there are multiplicands, extract positive addend and use it as an
    // accumulator
    // extractPositiveAddend may return nullptr, in which case
    // identifyMultiplications starts without an accumulator.
    FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
    FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
    if (!FinalNode)
      return nullptr;
  }

  // Identify and process remaining additions
  if (!RealAddends.empty() || !ImagAddends.empty()) {
    FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
    if (!FinalNode)
      return nullptr;
  }
  assert(FinalNode && "FinalNode can not be nullptr here");
  // Set the Real and Imag fields of the final node and submit it
  FinalNode->Real = Real;
  FinalNode->Imag = Imag;
  submitCompositeNode(FinalNode);
  return FinalNode;
}
1101
1102bool ComplexDeinterleavingGraph::collectPartialMuls(
1103    const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1104    std::vector<PartialMulCandidate> &PartialMulCandidates) {
1105  // Helper function to extract a common operand from two products
1106  auto FindCommonInstruction = [](const Product &Real,
1107                                  const Product &Imag) -> Value * {
1108    if (Real.Multiplicand == Imag.Multiplicand ||
1109        Real.Multiplicand == Imag.Multiplier)
1110      return Real.Multiplicand;
1111
1112    if (Real.Multiplier == Imag.Multiplicand ||
1113        Real.Multiplier == Imag.Multiplier)
1114      return Real.Multiplier;
1115
1116    return nullptr;
1117  };
1118
1119  // Iterating over real and imaginary multiplications to find common operands
1120  // If a common operand is found, a partial multiplication candidate is created
1121  // and added to the candidates vector The function returns false if no common
1122  // operands are found for any product
1123  for (unsigned i = 0; i < RealMuls.size(); ++i) {
1124    bool FoundCommon = false;
1125    for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1126      auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1127      if (!Common)
1128        continue;
1129
1130      auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1131                                                   : RealMuls[i].Multiplicand;
1132      auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1133                                                   : ImagMuls[j].Multiplicand;
1134
1135      auto Node = identifyNode(A, B);
1136      if (Node) {
1137        FoundCommon = true;
1138        PartialMulCandidates.push_back({Common, Node, i, j, false});
1139      }
1140
1141      Node = identifyNode(B, A);
1142      if (Node) {
1143        FoundCommon = true;
1144        PartialMulCandidates.push_back({Common, Node, i, j, true});
1145      }
1146    }
1147    if (!FoundCommon)
1148      return false;
1149  }
1150  return true;
1151}
1152
1153ComplexDeinterleavingGraph::NodePtr
1154ComplexDeinterleavingGraph::identifyMultiplications(
1155    std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1156    NodePtr Accumulator = nullptr) {
1157  if (RealMuls.size() != ImagMuls.size())
1158    return nullptr;
1159
1160  std::vector<PartialMulCandidate> Info;
1161  if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1162    return nullptr;
1163
1164  // Map to store common instruction to node pointers
1165  std::map<Value *, NodePtr> CommonToNode;
1166  std::vector<bool> Processed(Info.size(), false);
1167  for (unsigned I = 0; I < Info.size(); ++I) {
1168    if (Processed[I])
1169      continue;
1170
1171    PartialMulCandidate &InfoA = Info[I];
1172    for (unsigned J = I + 1; J < Info.size(); ++J) {
1173      if (Processed[J])
1174        continue;
1175
1176      PartialMulCandidate &InfoB = Info[J];
1177      auto *InfoReal = &InfoA;
1178      auto *InfoImag = &InfoB;
1179
1180      auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1181      if (!NodeFromCommon) {
1182        std::swap(InfoReal, InfoImag);
1183        NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1184      }
1185      if (!NodeFromCommon)
1186        continue;
1187
1188      CommonToNode[InfoReal->Common] = NodeFromCommon;
1189      CommonToNode[InfoImag->Common] = NodeFromCommon;
1190      Processed[I] = true;
1191      Processed[J] = true;
1192    }
1193  }
1194
1195  std::vector<bool> ProcessedReal(RealMuls.size(), false);
1196  std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1197  NodePtr Result = Accumulator;
1198  for (auto &PMI : Info) {
1199    if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1200      continue;
1201
1202    auto It = CommonToNode.find(PMI.Common);
1203    // TODO: Process independent complex multiplications. Cases like this:
1204    //  A.real() * B where both A and B are complex numbers.
1205    if (It == CommonToNode.end()) {
1206      LLVM_DEBUG({
1207        dbgs() << "Unprocessed independent partial multiplication:\n";
1208        for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1209          dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1210                           << " multiplied by " << *Mul->Multiplicand << "\n";
1211      });
1212      return nullptr;
1213    }
1214
1215    auto &RealMul = RealMuls[PMI.RealIdx];
1216    auto &ImagMul = ImagMuls[PMI.ImagIdx];
1217
1218    auto NodeA = It->second;
1219    auto NodeB = PMI.Node;
1220    auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1221    // The following table illustrates the relationship between multiplications
1222    // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1223    // can see:
1224    //
1225    // Rotation |   Real |   Imag |
1226    // ---------+--------+--------+
1227    //        0 |  x * u |  x * v |
1228    //       90 | -y * v |  y * u |
1229    //      180 | -x * u | -x * v |
1230    //      270 |  y * v | -y * u |
1231    //
1232    // Check if the candidate can indeed be represented by partial
1233    // multiplication
1234    // TODO: Add support for multiplication by complex one
1235    if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1236        (!IsMultiplicandReal && !PMI.IsNodeInverted))
1237      continue;
1238
1239    // Determine the rotation based on the multiplications
1240    ComplexDeinterleavingRotation Rotation;
1241    if (IsMultiplicandReal) {
1242      // Detect 0 and 180 degrees rotation
1243      if (RealMul.IsPositive && ImagMul.IsPositive)
1244        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1245      else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1246        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1247      else
1248        continue;
1249
1250    } else {
1251      // Detect 90 and 270 degrees rotation
1252      if (!RealMul.IsPositive && ImagMul.IsPositive)
1253        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1254      else if (RealMul.IsPositive && !ImagMul.IsPositive)
1255        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1256      else
1257        continue;
1258    }
1259
1260    LLVM_DEBUG({
1261      dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1262      dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1263      dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1264      dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1265      dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1266      dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1267    });
1268
1269    NodePtr NodeMul = prepareCompositeNode(
1270        ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1271    NodeMul->Rotation = Rotation;
1272    NodeMul->addOperand(NodeA);
1273    NodeMul->addOperand(NodeB);
1274    if (Result)
1275      NodeMul->addOperand(Result);
1276    submitCompositeNode(NodeMul);
1277    Result = NodeMul;
1278    ProcessedReal[PMI.RealIdx] = true;
1279    ProcessedImag[PMI.ImagIdx] = true;
1280  }
1281
1282  // Ensure all products have been processed, if not return nullptr.
1283  if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1284      !all_of(ProcessedImag, [](bool V) { return V; })) {
1285
1286    // Dump debug information about which partial multiplications are not
1287    // processed.
1288    LLVM_DEBUG({
1289      dbgs() << "Unprocessed products (Real):\n";
1290      for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1291        if (!ProcessedReal[i])
1292          dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1293                           << *RealMuls[i].Multiplier << " multiplied by "
1294                           << *RealMuls[i].Multiplicand << "\n";
1295      }
1296      dbgs() << "Unprocessed products (Imag):\n";
1297      for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1298        if (!ProcessedImag[i])
1299          dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1300                           << *ImagMuls[i].Multiplier << " multiplied by "
1301                           << *ImagMuls[i].Multiplicand << "\n";
1302      }
1303    });
1304    return nullptr;
1305  }
1306
1307  return Result;
1308}
1309
1310ComplexDeinterleavingGraph::NodePtr
1311ComplexDeinterleavingGraph::identifyAdditions(
1312    std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1313    std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1314  if (RealAddends.size() != ImagAddends.size())
1315    return nullptr;
1316
1317  NodePtr Result;
1318  // If we have accumulator use it as first addend
1319  if (Accumulator)
1320    Result = Accumulator;
1321  // Otherwise find an element with both positive real and imaginary parts.
1322  else
1323    Result = extractPositiveAddend(RealAddends, ImagAddends);
1324
1325  if (!Result)
1326    return nullptr;
1327
1328  while (!RealAddends.empty()) {
1329    auto ItR = RealAddends.begin();
1330    auto [R, IsPositiveR] = *ItR;
1331
1332    bool FoundImag = false;
1333    for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1334      auto [I, IsPositiveI] = *ItI;
1335      ComplexDeinterleavingRotation Rotation;
1336      if (IsPositiveR && IsPositiveI)
1337        Rotation = ComplexDeinterleavingRotation::Rotation_0;
1338      else if (!IsPositiveR && IsPositiveI)
1339        Rotation = ComplexDeinterleavingRotation::Rotation_90;
1340      else if (!IsPositiveR && !IsPositiveI)
1341        Rotation = ComplexDeinterleavingRotation::Rotation_180;
1342      else
1343        Rotation = ComplexDeinterleavingRotation::Rotation_270;
1344
1345      NodePtr AddNode;
1346      if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1347          Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1348        AddNode = identifyNode(R, I);
1349      } else {
1350        AddNode = identifyNode(I, R);
1351      }
1352      if (AddNode) {
1353        LLVM_DEBUG({
1354          dbgs() << "Identified addition:\n";
1355          dbgs().indent(4) << "X: " << *R << "\n";
1356          dbgs().indent(4) << "Y: " << *I << "\n";
1357          dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1358        });
1359
1360        NodePtr TmpNode;
1361        if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1362          TmpNode = prepareCompositeNode(
1363              ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1364          if (Flags) {
1365            TmpNode->Opcode = Instruction::FAdd;
1366            TmpNode->Flags = *Flags;
1367          } else {
1368            TmpNode->Opcode = Instruction::Add;
1369          }
1370        } else if (Rotation ==
1371                   llvm::ComplexDeinterleavingRotation::Rotation_180) {
1372          TmpNode = prepareCompositeNode(
1373              ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1374          if (Flags) {
1375            TmpNode->Opcode = Instruction::FSub;
1376            TmpNode->Flags = *Flags;
1377          } else {
1378            TmpNode->Opcode = Instruction::Sub;
1379          }
1380        } else {
1381          TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1382                                         nullptr, nullptr);
1383          TmpNode->Rotation = Rotation;
1384        }
1385
1386        TmpNode->addOperand(Result);
1387        TmpNode->addOperand(AddNode);
1388        submitCompositeNode(TmpNode);
1389        Result = TmpNode;
1390        RealAddends.erase(ItR);
1391        ImagAddends.erase(ItI);
1392        FoundImag = true;
1393        break;
1394      }
1395    }
1396    if (!FoundImag)
1397      return nullptr;
1398  }
1399  return Result;
1400}
1401
1402ComplexDeinterleavingGraph::NodePtr
1403ComplexDeinterleavingGraph::extractPositiveAddend(
1404    std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1405  for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1406    for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1407      auto [R, IsPositiveR] = *ItR;
1408      auto [I, IsPositiveI] = *ItI;
1409      if (IsPositiveR && IsPositiveI) {
1410        auto Result = identifyNode(R, I);
1411        if (Result) {
1412          RealAddends.erase(ItR);
1413          ImagAddends.erase(ItI);
1414          return Result;
1415        }
1416      }
1417    }
1418  }
1419  return nullptr;
1420}
1421
1422bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1423  // This potential root instruction might already have been recognized as
1424  // reduction. Because RootToNode maps both Real and Imaginary parts to
1425  // CompositeNode we should choose only one either Real or Imag instruction to
1426  // use as an anchor for generating complex instruction.
1427  auto It = RootToNode.find(RootI);
1428  if (It != RootToNode.end()) {
1429    auto RootNode = It->second;
1430    assert(RootNode->Operation ==
1431           ComplexDeinterleavingOperation::ReductionOperation);
1432    // Find out which part, Real or Imag, comes later, and only if we come to
1433    // the latest part, add it to OrderedRoots.
1434    auto *R = cast<Instruction>(RootNode->Real);
1435    auto *I = cast<Instruction>(RootNode->Imag);
1436    auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1437    if (ReplacementAnchor != RootI)
1438      return false;
1439    OrderedRoots.push_back(RootI);
1440    return true;
1441  }
1442
1443  auto RootNode = identifyRoot(RootI);
1444  if (!RootNode)
1445    return false;
1446
1447  LLVM_DEBUG({
1448    Function *F = RootI->getFunction();
1449    BasicBlock *B = RootI->getParent();
1450    dbgs() << "Complex deinterleaving graph for " << F->getName()
1451           << "::" << B->getName() << ".\n";
1452    dump(dbgs());
1453    dbgs() << "\n";
1454  });
1455  RootToNode[RootI] = RootNode;
1456  OrderedRoots.push_back(RootI);
1457  return true;
1458}
1459
1460bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1461  bool FoundPotentialReduction = false;
1462
1463  auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1464  if (!Br || Br->getNumSuccessors() != 2)
1465    return false;
1466
1467  // Identify simple one-block loop
1468  if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1469    return false;
1470
1471  SmallVector<PHINode *> PHIs;
1472  for (auto &PHI : B->phis()) {
1473    if (PHI.getNumIncomingValues() != 2)
1474      continue;
1475
1476    if (!PHI.getType()->isVectorTy())
1477      continue;
1478
1479    auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1480    if (!ReductionOp)
1481      continue;
1482
1483    // Check if final instruction is reduced outside of current block
1484    Instruction *FinalReduction = nullptr;
1485    auto NumUsers = 0u;
1486    for (auto *U : ReductionOp->users()) {
1487      ++NumUsers;
1488      if (U == &PHI)
1489        continue;
1490      FinalReduction = dyn_cast<Instruction>(U);
1491    }
1492
1493    if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1494        isa<PHINode>(FinalReduction))
1495      continue;
1496
1497    ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1498    BackEdge = B;
1499    auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1500    auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1501    Incoming = PHI.getIncomingBlock(IncomingIdx);
1502    FoundPotentialReduction = true;
1503
1504    // If the initial value of PHINode is an Instruction, consider it a leaf
1505    // value of a complex deinterleaving graph.
1506    if (auto *InitPHI =
1507            dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1508      FinalInstructions.insert(InitPHI);
1509  }
1510  return FoundPotentialReduction;
1511}
1512
1513void ComplexDeinterleavingGraph::identifyReductionNodes() {
1514  SmallVector<bool> Processed(ReductionInfo.size(), false);
1515  SmallVector<Instruction *> OperationInstruction;
1516  for (auto &P : ReductionInfo)
1517    OperationInstruction.push_back(P.first);
1518
1519  // Identify a complex computation by evaluating two reduction operations that
1520  // potentially could be involved
1521  for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1522    if (Processed[i])
1523      continue;
1524    for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1525      if (Processed[j])
1526        continue;
1527
1528      auto *Real = OperationInstruction[i];
1529      auto *Imag = OperationInstruction[j];
1530      if (Real->getType() != Imag->getType())
1531        continue;
1532
1533      RealPHI = ReductionInfo[Real].first;
1534      ImagPHI = ReductionInfo[Imag].first;
1535      PHIsFound = false;
1536      auto Node = identifyNode(Real, Imag);
1537      if (!Node) {
1538        std::swap(Real, Imag);
1539        std::swap(RealPHI, ImagPHI);
1540        Node = identifyNode(Real, Imag);
1541      }
1542
1543      // If a node is identified and reduction PHINode is used in the chain of
1544      // operations, mark its operation instructions as used to prevent
1545      // re-identification and attach the node to the real part
1546      if (Node && PHIsFound) {
1547        LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1548                          << *Real << " / " << *Imag << "\n");
1549        Processed[i] = true;
1550        Processed[j] = true;
1551        auto RootNode = prepareCompositeNode(
1552            ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1553        RootNode->addOperand(Node);
1554        RootToNode[Real] = RootNode;
1555        RootToNode[Imag] = RootNode;
1556        submitCompositeNode(RootNode);
1557        break;
1558      }
1559    }
1560  }
1561
1562  RealPHI = nullptr;
1563  ImagPHI = nullptr;
1564}
1565
1566bool ComplexDeinterleavingGraph::checkNodes() {
1567  // Collect all instructions from roots to leaves
1568  SmallPtrSet<Instruction *, 16> AllInstructions;
1569  SmallVector<Instruction *, 8> Worklist;
1570  for (auto &Pair : RootToNode)
1571    Worklist.push_back(Pair.first);
1572
1573  // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1574  // chains
1575  while (!Worklist.empty()) {
1576    auto *I = Worklist.back();
1577    Worklist.pop_back();
1578
1579    if (!AllInstructions.insert(I).second)
1580      continue;
1581
1582    for (Value *Op : I->operands()) {
1583      if (auto *OpI = dyn_cast<Instruction>(Op)) {
1584        if (!FinalInstructions.count(I))
1585          Worklist.emplace_back(OpI);
1586      }
1587    }
1588  }
1589
1590  // Find instructions that have users outside of chain
1591  SmallVector<Instruction *, 2> OuterInstructions;
1592  for (auto *I : AllInstructions) {
1593    // Skip root nodes
1594    if (RootToNode.count(I))
1595      continue;
1596
1597    for (User *U : I->users()) {
1598      if (AllInstructions.count(cast<Instruction>(U)))
1599        continue;
1600
1601      // Found an instruction that is not used by XCMLA/XCADD chain
1602      Worklist.emplace_back(I);
1603      break;
1604    }
1605  }
1606
1607  // If any instructions are found to be used outside, find and remove roots
1608  // that somehow connect to those instructions.
1609  SmallPtrSet<Instruction *, 16> Visited;
1610  while (!Worklist.empty()) {
1611    auto *I = Worklist.back();
1612    Worklist.pop_back();
1613    if (!Visited.insert(I).second)
1614      continue;
1615
1616    // Found an impacted root node. Removing it from the nodes to be
1617    // deinterleaved
1618    if (RootToNode.count(I)) {
1619      LLVM_DEBUG(dbgs() << "Instruction " << *I
1620                        << " could be deinterleaved but its chain of complex "
1621                           "operations have an outside user\n");
1622      RootToNode.erase(I);
1623    }
1624
1625    if (!AllInstructions.count(I) || FinalInstructions.count(I))
1626      continue;
1627
1628    for (User *U : I->users())
1629      Worklist.emplace_back(cast<Instruction>(U));
1630
1631    for (Value *Op : I->operands()) {
1632      if (auto *OpI = dyn_cast<Instruction>(Op))
1633        Worklist.emplace_back(OpI);
1634    }
1635  }
1636  return !RootToNode.empty();
1637}
1638
1639ComplexDeinterleavingGraph::NodePtr
1640ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1641  if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1642    if (Intrinsic->getIntrinsicID() !=
1643        Intrinsic::experimental_vector_interleave2)
1644      return nullptr;
1645
1646    auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1647    auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1648    if (!Real || !Imag)
1649      return nullptr;
1650
1651    return identifyNode(Real, Imag);
1652  }
1653
1654  auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1655  if (!SVI)
1656    return nullptr;
1657
1658  // Look for a shufflevector that takes separate vectors of the real and
1659  // imaginary components and recombines them into a single vector.
1660  if (!isInterleavingMask(SVI->getShuffleMask()))
1661    return nullptr;
1662
1663  Instruction *Real;
1664  Instruction *Imag;
1665  if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1666    return nullptr;
1667
1668  return identifyNode(Real, Imag);
1669}
1670
1671ComplexDeinterleavingGraph::NodePtr
1672ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1673                                                 Instruction *Imag) {
1674  Instruction *I = nullptr;
1675  Value *FinalValue = nullptr;
1676  if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1677      match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1678      match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1679                   m_Value(FinalValue)))) {
1680    NodePtr PlaceholderNode = prepareCompositeNode(
1681        llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1682    PlaceholderNode->ReplacementNode = FinalValue;
1683    FinalInstructions.insert(Real);
1684    FinalInstructions.insert(Imag);
1685    return submitCompositeNode(PlaceholderNode);
1686  }
1687
1688  auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1689  auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1690  if (!RealShuffle || !ImagShuffle) {
1691    if (RealShuffle || ImagShuffle)
1692      LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1693    return nullptr;
1694  }
1695
1696  Value *RealOp1 = RealShuffle->getOperand(1);
1697  if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1698    LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1699    return nullptr;
1700  }
1701  Value *ImagOp1 = ImagShuffle->getOperand(1);
1702  if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1703    LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1704    return nullptr;
1705  }
1706
1707  Value *RealOp0 = RealShuffle->getOperand(0);
1708  Value *ImagOp0 = ImagShuffle->getOperand(0);
1709
1710  if (RealOp0 != ImagOp0) {
1711    LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1712    return nullptr;
1713  }
1714
1715  ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1716  ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1717  if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1718    LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1719    return nullptr;
1720  }
1721
1722  if (RealMask[0] != 0 || ImagMask[0] != 1) {
1723    LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1724    return nullptr;
1725  }
1726
1727  // Type checking, the shuffle type should be a vector type of the same
1728  // scalar type, but half the size
1729  auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1730    Value *Op = Shuffle->getOperand(0);
1731    auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1732    auto *OpTy = cast<FixedVectorType>(Op->getType());
1733
1734    if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1735      return false;
1736    if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1737      return false;
1738
1739    return true;
1740  };
1741
1742  auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1743    if (!CheckType(Shuffle))
1744      return false;
1745
1746    ArrayRef<int> Mask = Shuffle->getShuffleMask();
1747    int Last = *Mask.rbegin();
1748
1749    Value *Op = Shuffle->getOperand(0);
1750    auto *OpTy = cast<FixedVectorType>(Op->getType());
1751    int NumElements = OpTy->getNumElements();
1752
1753    // Ensure that the deinterleaving shuffle only pulls from the first
1754    // shuffle operand.
1755    return Last < NumElements;
1756  };
1757
1758  if (RealShuffle->getType() != ImagShuffle->getType()) {
1759    LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1760    return nullptr;
1761  }
1762  if (!CheckDeinterleavingShuffle(RealShuffle)) {
1763    LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1764    return nullptr;
1765  }
1766  if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1767    LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1768    return nullptr;
1769  }
1770
1771  NodePtr PlaceholderNode =
1772      prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1773                           RealShuffle, ImagShuffle);
1774  PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1775  FinalInstructions.insert(RealShuffle);
1776  FinalInstructions.insert(ImagShuffle);
1777  return submitCompositeNode(PlaceholderNode);
1778}
1779
1780ComplexDeinterleavingGraph::NodePtr
1781ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1782  auto IsSplat = [](Value *V) -> bool {
1783    // Fixed-width vector with constants
1784    if (isa<ConstantDataVector>(V))
1785      return true;
1786
1787    VectorType *VTy;
1788    ArrayRef<int> Mask;
1789    // Splats are represented differently depending on whether the repeated
1790    // value is a constant or an Instruction
1791    if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1792      if (Const->getOpcode() != Instruction::ShuffleVector)
1793        return false;
1794      VTy = cast<VectorType>(Const->getType());
1795      Mask = Const->getShuffleMask();
1796    } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1797      VTy = Shuf->getType();
1798      Mask = Shuf->getShuffleMask();
1799    } else {
1800      return false;
1801    }
1802
1803    // When the data type is <1 x Type>, it's not possible to differentiate
1804    // between the ComplexDeinterleaving::Deinterleave and
1805    // ComplexDeinterleaving::Splat operations.
1806    if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1807      return false;
1808
1809    return all_equal(Mask) && Mask[0] == 0;
1810  };
1811
1812  if (!IsSplat(R) || !IsSplat(I))
1813    return nullptr;
1814
1815  auto *Real = dyn_cast<Instruction>(R);
1816  auto *Imag = dyn_cast<Instruction>(I);
1817  if ((!Real && Imag) || (Real && !Imag))
1818    return nullptr;
1819
1820  if (Real && Imag) {
1821    // Non-constant splats should be in the same basic block
1822    if (Real->getParent() != Imag->getParent())
1823      return nullptr;
1824
1825    FinalInstructions.insert(Real);
1826    FinalInstructions.insert(Imag);
1827  }
1828  NodePtr PlaceholderNode =
1829      prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1830  return submitCompositeNode(PlaceholderNode);
1831}
1832
1833ComplexDeinterleavingGraph::NodePtr
1834ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1835                                            Instruction *Imag) {
1836  if (Real != RealPHI || Imag != ImagPHI)
1837    return nullptr;
1838
1839  PHIsFound = true;
1840  NodePtr PlaceholderNode = prepareCompositeNode(
1841      ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1842  return submitCompositeNode(PlaceholderNode);
1843}
1844
1845ComplexDeinterleavingGraph::NodePtr
1846ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1847                                               Instruction *Imag) {
1848  auto *SelectReal = dyn_cast<SelectInst>(Real);
1849  auto *SelectImag = dyn_cast<SelectInst>(Imag);
1850  if (!SelectReal || !SelectImag)
1851    return nullptr;
1852
1853  Instruction *MaskA, *MaskB;
1854  Instruction *AR, *AI, *RA, *BI;
1855  if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1856                            m_Instruction(RA))) ||
1857      !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1858                            m_Instruction(BI))))
1859    return nullptr;
1860
1861  if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1862    return nullptr;
1863
1864  if (!MaskA->getType()->isVectorTy())
1865    return nullptr;
1866
1867  auto NodeA = identifyNode(AR, AI);
1868  if (!NodeA)
1869    return nullptr;
1870
1871  auto NodeB = identifyNode(RA, BI);
1872  if (!NodeB)
1873    return nullptr;
1874
1875  NodePtr PlaceholderNode = prepareCompositeNode(
1876      ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1877  PlaceholderNode->addOperand(NodeA);
1878  PlaceholderNode->addOperand(NodeB);
1879  FinalInstructions.insert(MaskA);
1880  FinalInstructions.insert(MaskB);
1881  return submitCompositeNode(PlaceholderNode);
1882}
1883
1884static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1885                                   std::optional<FastMathFlags> Flags,
1886                                   Value *InputA, Value *InputB) {
1887  Value *I;
1888  switch (Opcode) {
1889  case Instruction::FNeg:
1890    I = B.CreateFNeg(InputA);
1891    break;
1892  case Instruction::FAdd:
1893    I = B.CreateFAdd(InputA, InputB);
1894    break;
1895  case Instruction::Add:
1896    I = B.CreateAdd(InputA, InputB);
1897    break;
1898  case Instruction::FSub:
1899    I = B.CreateFSub(InputA, InputB);
1900    break;
1901  case Instruction::Sub:
1902    I = B.CreateSub(InputA, InputB);
1903    break;
1904  case Instruction::FMul:
1905    I = B.CreateFMul(InputA, InputB);
1906    break;
1907  case Instruction::Mul:
1908    I = B.CreateMul(InputA, InputB);
1909    break;
1910  default:
1911    llvm_unreachable("Incorrect symmetric opcode");
1912  }
1913  if (Flags)
1914    cast<Instruction>(I)->setFastMathFlags(*Flags);
1915  return I;
1916}
1917
1918Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1919                                               RawNodePtr Node) {
1920  if (Node->ReplacementNode)
1921    return Node->ReplacementNode;
1922
1923  auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1924    return Node->Operands.size() > Idx
1925               ? replaceNode(Builder, Node->Operands[Idx])
1926               : nullptr;
1927  };
1928
1929  Value *ReplacementNode;
1930  switch (Node->Operation) {
1931  case ComplexDeinterleavingOperation::CAdd:
1932  case ComplexDeinterleavingOperation::CMulPartial:
1933  case ComplexDeinterleavingOperation::Symmetric: {
1934    Value *Input0 = ReplaceOperandIfExist(Node, 0);
1935    Value *Input1 = ReplaceOperandIfExist(Node, 1);
1936    Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1937    assert(!Input1 || (Input0->getType() == Input1->getType() &&
1938                       "Node inputs need to be of the same type"));
1939    assert(!Accumulator ||
1940           (Input0->getType() == Accumulator->getType() &&
1941            "Accumulator and input need to be of the same type"));
1942    if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1943      ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1944                                             Input0, Input1);
1945    else
1946      ReplacementNode = TL->createComplexDeinterleavingIR(
1947          Builder, Node->Operation, Node->Rotation, Input0, Input1,
1948          Accumulator);
1949    break;
1950  }
1951  case ComplexDeinterleavingOperation::Deinterleave:
1952    llvm_unreachable("Deinterleave node should already have ReplacementNode");
1953    break;
1954  case ComplexDeinterleavingOperation::Splat: {
1955    auto *NewTy = VectorType::getDoubleElementsVectorType(
1956        cast<VectorType>(Node->Real->getType()));
1957    auto *R = dyn_cast<Instruction>(Node->Real);
1958    auto *I = dyn_cast<Instruction>(Node->Imag);
1959    if (R && I) {
1960      // Splats that are not constant are interleaved where they are located
1961      Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1962      IRBuilder<> IRB(InsertPoint);
1963      ReplacementNode =
1964          IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
1965                              {Node->Real, Node->Imag});
1966    } else {
1967      ReplacementNode =
1968          Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1969                                  NewTy, {Node->Real, Node->Imag});
1970    }
1971    break;
1972  }
1973  case ComplexDeinterleavingOperation::ReductionPHI: {
1974    // If Operation is ReductionPHI, a new empty PHINode is created.
1975    // It is filled later when the ReductionOperation is processed.
1976    auto *VTy = cast<VectorType>(Node->Real->getType());
1977    auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1978    auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1979    OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1980    ReplacementNode = NewPHI;
1981    break;
1982  }
1983  case ComplexDeinterleavingOperation::ReductionOperation:
1984    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1985    processReductionOperation(ReplacementNode, Node);
1986    break;
1987  case ComplexDeinterleavingOperation::ReductionSelect: {
1988    auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1989    auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1990    auto *A = replaceNode(Builder, Node->Operands[0]);
1991    auto *B = replaceNode(Builder, Node->Operands[1]);
1992    auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1993        cast<VectorType>(MaskReal->getType()));
1994    auto *NewMask =
1995        Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1996                                NewMaskTy, {MaskReal, MaskImag});
1997    ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1998    break;
1999  }
2000  }
2001
2002  assert(ReplacementNode && "Target failed to create Intrinsic call.");
2003  NumComplexTransformations += 1;
2004  Node->ReplacementNode = ReplacementNode;
2005  return ReplacementNode;
2006}
2007
2008void ComplexDeinterleavingGraph::processReductionOperation(
2009    Value *OperationReplacement, RawNodePtr Node) {
2010  auto *Real = cast<Instruction>(Node->Real);
2011  auto *Imag = cast<Instruction>(Node->Imag);
2012  auto *OldPHIReal = ReductionInfo[Real].first;
2013  auto *OldPHIImag = ReductionInfo[Imag].first;
2014  auto *NewPHI = OldToNewPHI[OldPHIReal];
2015
2016  auto *VTy = cast<VectorType>(Real->getType());
2017  auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2018
2019  // We have to interleave initial origin values coming from IncomingBlock
2020  Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2021  Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2022
2023  IRBuilder<> Builder(Incoming->getTerminator());
2024  auto *NewInit = Builder.CreateIntrinsic(
2025      Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
2026
2027  NewPHI->addIncoming(NewInit, Incoming);
2028  NewPHI->addIncoming(OperationReplacement, BackEdge);
2029
2030  // Deinterleave complex vector outside of loop so that it can be finally
2031  // reduced
2032  auto *FinalReductionReal = ReductionInfo[Real].second;
2033  auto *FinalReductionImag = ReductionInfo[Imag].second;
2034
2035  Builder.SetInsertPoint(
2036      &*FinalReductionReal->getParent()->getFirstInsertionPt());
2037  auto *Deinterleave = Builder.CreateIntrinsic(
2038      Intrinsic::experimental_vector_deinterleave2,
2039      OperationReplacement->getType(), OperationReplacement);
2040
2041  auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2042  FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2043
2044  Builder.SetInsertPoint(FinalReductionImag);
2045  auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2046  FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2047}
2048
2049void ComplexDeinterleavingGraph::replaceNodes() {
2050  SmallVector<Instruction *, 16> DeadInstrRoots;
2051  for (auto *RootInstruction : OrderedRoots) {
2052    // Check if this potential root went through check process and we can
2053    // deinterleave it
2054    if (!RootToNode.count(RootInstruction))
2055      continue;
2056
2057    IRBuilder<> Builder(RootInstruction);
2058    auto RootNode = RootToNode[RootInstruction];
2059    Value *R = replaceNode(Builder, RootNode.get());
2060
2061    if (RootNode->Operation ==
2062        ComplexDeinterleavingOperation::ReductionOperation) {
2063      auto *RootReal = cast<Instruction>(RootNode->Real);
2064      auto *RootImag = cast<Instruction>(RootNode->Imag);
2065      ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2066      ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2067      DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2068      DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2069    } else {
2070      assert(R && "Unable to find replacement for RootInstruction");
2071      DeadInstrRoots.push_back(RootInstruction);
2072      RootInstruction->replaceAllUsesWith(R);
2073    }
2074  }
2075
2076  for (auto *I : DeadInstrRoots)
2077    RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2078}
2079