X86CondBrFolding.cpp revision 360784
1//===---- X86CondBrFolding.cpp - optimize conditional branches ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This file defines a pass that optimizes condition branches on x86 by taking
9// advantage of the three-way conditional code generated by compare
10// instructions.
// Currently, it tries to hoist EQ and NE conditional branches into a dominating
// conditional branch whose compare can also produce the same EQ/NE condition.
// An example:
14//   bb_0:
15//     cmp %0, 19
16//     jg bb_1
17//     jmp bb_2
18//   bb_1:
19//     cmp %0, 40
20//     jg bb_3
21//     jmp bb_4
22//   bb_4:
23//     cmp %0, 20
24//     je bb_5
25//     jmp bb_6
26// Here we could combine the two compares in bb_0 and bb_4 and have the
27// following code:
28//   bb_0:
29//     cmp %0, 20
30//     jg bb_1
31//     jl bb_2
32//     jmp bb_5
33//   bb_1:
34//     cmp %0, 40
35//     jg bb_3
36//     jmp bb_6
// For the case of %0 == 20 (bb_5), we eliminate two jumps, and fewer branches
// are executed before reaching bb_6. bb_4 is gone after the optimization.
39//
// This code pattern is common, especially from switch-case lowering, where we
// generate compares against "pivot-1" for the inner nodes of the binary
// search tree.
43//===----------------------------------------------------------------------===//
44
#include "X86.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/Support/BranchProbability.h"
#include <algorithm>
#include <cstdint>
54
55using namespace llvm;
56
// Debug category for LLVM_DEBUG output and the STATISTIC in this file.
#define DEBUG_TYPE "x86-condbr-folding"

// Incremented in optimize() by the number of candidate blocks erased.
STATISTIC(NumFixedCondBrs, "Number of x86 condbr folded");
60
61namespace {
62class X86CondBrFoldingPass : public MachineFunctionPass {
63public:
64  X86CondBrFoldingPass() : MachineFunctionPass(ID) { }
65  StringRef getPassName() const override { return "X86 CondBr Folding"; }
66
67  bool runOnMachineFunction(MachineFunction &MF) override;
68
69  void getAnalysisUsage(AnalysisUsage &AU) const override {
70    MachineFunctionPass::getAnalysisUsage(AU);
71    AU.addRequired<MachineBranchProbabilityInfo>();
72  }
73
74public:
75  static char ID;
76};
77} // namespace
78
79char X86CondBrFoldingPass::ID = 0;
80INITIALIZE_PASS(X86CondBrFoldingPass, "X86CondBrFolding", "X86CondBrFolding", false, false)
81
82FunctionPass *llvm::createX86CondBrFolding() {
83  return new X86CondBrFoldingPass();
84}
85
86namespace {
87// A class the stores the auxiliary information for each MBB.
88struct TargetMBBInfo {
89  MachineBasicBlock *TBB;
90  MachineBasicBlock *FBB;
91  MachineInstr *BrInstr;
92  MachineInstr *CmpInstr;
93  X86::CondCode BranchCode;
94  unsigned SrcReg;
95  int CmpValue;
96  bool Modified;
97  bool CmpBrOnly;
98};
99
100// A class that optimizes the conditional branch by hoisting and merge CondCode.
101class X86CondBrFolding {
102public:
103  X86CondBrFolding(const X86InstrInfo *TII,
104                   const MachineBranchProbabilityInfo *MBPI,
105                   MachineFunction &MF)
106      : TII(TII), MBPI(MBPI), MF(MF) {}
107  bool optimize();
108
109private:
110  const X86InstrInfo *TII;
111  const MachineBranchProbabilityInfo *MBPI;
112  MachineFunction &MF;
113  std::vector<std::unique_ptr<TargetMBBInfo>> MBBInfos;
114  SmallVector<MachineBasicBlock *, 4> RemoveList;
115
116  void optimizeCondBr(MachineBasicBlock &MBB,
117                      SmallVectorImpl<MachineBasicBlock *> &BranchPath);
118  void replaceBrDest(MachineBasicBlock *MBB, MachineBasicBlock *OrigDest,
119                     MachineBasicBlock *NewDest);
120  void fixupModifiedCond(MachineBasicBlock *MBB);
121  std::unique_ptr<TargetMBBInfo> analyzeMBB(MachineBasicBlock &MBB);
122  static bool analyzeCompare(const MachineInstr &MI, unsigned &SrcReg,
123                             int &CmpValue);
124  bool findPath(MachineBasicBlock *MBB,
125                SmallVectorImpl<MachineBasicBlock *> &BranchPath);
126  TargetMBBInfo *getMBBInfo(MachineBasicBlock *MBB) const {
127    return MBBInfos[MBB->getNumber()].get();
128  }
129};
130} // namespace
131
132// Find a valid path that we can reuse the CondCode.
133// The resulted path (if return true) is stored in BranchPath.
134// Return value:
135//  false: is no valid path is found.
136//  true: a valid path is found and the targetBB can be reached.
137bool X86CondBrFolding::findPath(
138    MachineBasicBlock *MBB, SmallVectorImpl<MachineBasicBlock *> &BranchPath) {
139  TargetMBBInfo *MBBInfo = getMBBInfo(MBB);
140  assert(MBBInfo && "Expecting a candidate MBB");
141  int CmpValue = MBBInfo->CmpValue;
142
143  MachineBasicBlock *PredMBB = *MBB->pred_begin();
144  MachineBasicBlock *SaveMBB = MBB;
145  while (PredMBB) {
146    TargetMBBInfo *PredMBBInfo = getMBBInfo(PredMBB);
147    if (!PredMBBInfo || PredMBBInfo->SrcReg != MBBInfo->SrcReg)
148      return false;
149
150    assert(SaveMBB == PredMBBInfo->TBB || SaveMBB == PredMBBInfo->FBB);
151    bool IsFalseBranch = (SaveMBB == PredMBBInfo->FBB);
152
153    X86::CondCode CC = PredMBBInfo->BranchCode;
154    assert(CC == X86::COND_L || CC == X86::COND_G || CC == X86::COND_E);
155    int PredCmpValue = PredMBBInfo->CmpValue;
156    bool ValueCmpTrue = ((CmpValue < PredCmpValue && CC == X86::COND_L) ||
157                         (CmpValue > PredCmpValue && CC == X86::COND_G) ||
158                         (CmpValue == PredCmpValue && CC == X86::COND_E));
159    // Check if both the result of value compare and the branch target match.
160    if (!(ValueCmpTrue ^ IsFalseBranch)) {
161      LLVM_DEBUG(dbgs() << "Dead BB detected!\n");
162      return false;
163    }
164
165    BranchPath.push_back(PredMBB);
166    // These are the conditions on which we could combine the compares.
167    if ((CmpValue == PredCmpValue) ||
168        (CmpValue == PredCmpValue - 1 && CC == X86::COND_L) ||
169        (CmpValue == PredCmpValue + 1 && CC == X86::COND_G))
170      return true;
171
172    // If PredMBB has more than on preds, or not a pure cmp and br, we bailout.
173    if (PredMBB->pred_size() != 1 || !PredMBBInfo->CmpBrOnly)
174      return false;
175
176    SaveMBB = PredMBB;
177    PredMBB = *PredMBB->pred_begin();
178  }
179  return false;
180}
181
182// Fix up any PHI node in the successor of MBB.
183static void fixPHIsInSucc(MachineBasicBlock *MBB, MachineBasicBlock *OldMBB,
184                          MachineBasicBlock *NewMBB) {
185  if (NewMBB == OldMBB)
186    return;
187  for (auto MI = MBB->instr_begin(), ME = MBB->instr_end();
188       MI != ME && MI->isPHI(); ++MI)
189    for (unsigned i = 2, e = MI->getNumOperands() + 1; i != e; i += 2) {
190      MachineOperand &MO = MI->getOperand(i);
191      if (MO.getMBB() == OldMBB)
192        MO.setMBB(NewMBB);
193    }
194}
195
196// Utility function to set branch probability for edge MBB->SuccMBB.
197static inline bool setBranchProb(MachineBasicBlock *MBB,
198                                 MachineBasicBlock *SuccMBB,
199                                 BranchProbability Prob) {
200  auto MBBI = std::find(MBB->succ_begin(), MBB->succ_end(), SuccMBB);
201  if (MBBI == MBB->succ_end())
202    return false;
203  MBB->setSuccProbability(MBBI, Prob);
204  return true;
205}
206
207// Utility function to find the unconditional br instruction in MBB.
208static inline MachineBasicBlock::iterator
209findUncondBrI(MachineBasicBlock *MBB) {
210  return std::find_if(MBB->begin(), MBB->end(), [](MachineInstr &MI) -> bool {
211    return MI.getOpcode() == X86::JMP_1;
212  });
213}
214
215// Replace MBB's original successor, OrigDest, with NewDest.
216// Also update the MBBInfo for MBB.
217void X86CondBrFolding::replaceBrDest(MachineBasicBlock *MBB,
218                                     MachineBasicBlock *OrigDest,
219                                     MachineBasicBlock *NewDest) {
220  TargetMBBInfo *MBBInfo = getMBBInfo(MBB);
221  MachineInstr *BrMI;
222  if (MBBInfo->TBB == OrigDest) {
223    BrMI = MBBInfo->BrInstr;
224    MachineInstrBuilder MIB =
225        BuildMI(*MBB, BrMI, MBB->findDebugLoc(BrMI), TII->get(X86::JCC_1))
226            .addMBB(NewDest).addImm(MBBInfo->BranchCode);
227    MBBInfo->TBB = NewDest;
228    MBBInfo->BrInstr = MIB.getInstr();
229  } else { // Should be the unconditional jump stmt.
230    MachineBasicBlock::iterator UncondBrI = findUncondBrI(MBB);
231    BuildMI(*MBB, UncondBrI, MBB->findDebugLoc(UncondBrI), TII->get(X86::JMP_1))
232        .addMBB(NewDest);
233    MBBInfo->FBB = NewDest;
234    BrMI = &*UncondBrI;
235  }
236  fixPHIsInSucc(NewDest, OrigDest, MBB);
237  BrMI->eraseFromParent();
238  MBB->addSuccessor(NewDest);
239  setBranchProb(MBB, NewDest, MBPI->getEdgeProbability(MBB, OrigDest));
240  MBB->removeSuccessor(OrigDest);
241}
242
243// Change the CondCode and BrInstr according to MBBInfo.
244void X86CondBrFolding::fixupModifiedCond(MachineBasicBlock *MBB) {
245  TargetMBBInfo *MBBInfo = getMBBInfo(MBB);
246  if (!MBBInfo->Modified)
247    return;
248
249  MachineInstr *BrMI = MBBInfo->BrInstr;
250  X86::CondCode CC = MBBInfo->BranchCode;
251  MachineInstrBuilder MIB = BuildMI(*MBB, BrMI, MBB->findDebugLoc(BrMI),
252                                    TII->get(X86::JCC_1))
253                                .addMBB(MBBInfo->TBB).addImm(CC);
254  BrMI->eraseFromParent();
255  MBBInfo->BrInstr = MIB.getInstr();
256
257  MachineBasicBlock::iterator UncondBrI = findUncondBrI(MBB);
258  BuildMI(*MBB, UncondBrI, MBB->findDebugLoc(UncondBrI), TII->get(X86::JMP_1))
259      .addMBB(MBBInfo->FBB);
260  MBB->erase(UncondBrI);
261  MBBInfo->Modified = false;
262}
263
264//
265// Apply the transformation:
266//  RootMBB -1-> ... PredMBB -3-> MBB -5-> TargetMBB
267//     \-2->           \-4->       \-6-> FalseMBB
268// ==>
269//             RootMBB -1-> ... PredMBB -7-> FalseMBB
270// TargetMBB <-8-/ \-2->           \-4->
271//
272// Note that PredMBB and RootMBB could be the same.
273// And in the case of dead TargetMBB, we will not have TargetMBB and edge 8.
274//
275// There are some special handling where the RootMBB is COND_E in which case
276// we directly short-cycle the brinstr.
277//
278void X86CondBrFolding::optimizeCondBr(
279    MachineBasicBlock &MBB, SmallVectorImpl<MachineBasicBlock *> &BranchPath) {
280
281  X86::CondCode CC;
282  TargetMBBInfo *MBBInfo = getMBBInfo(&MBB);
283  assert(MBBInfo && "Expecting a candidate MBB");
284  MachineBasicBlock *TargetMBB = MBBInfo->TBB;
285  BranchProbability TargetProb = MBPI->getEdgeProbability(&MBB, MBBInfo->TBB);
286
287  // Forward the jump from MBB's predecessor to MBB's false target.
288  MachineBasicBlock *PredMBB = BranchPath.front();
289  TargetMBBInfo *PredMBBInfo = getMBBInfo(PredMBB);
290  assert(PredMBBInfo && "Expecting a candidate MBB");
291  if (PredMBBInfo->Modified)
292    fixupModifiedCond(PredMBB);
293  CC = PredMBBInfo->BranchCode;
294  // Don't do this if depth of BranchPath is 1 and PredMBB is of COND_E.
295  // We will short-cycle directly for this case.
296  if (!(CC == X86::COND_E && BranchPath.size() == 1))
297    replaceBrDest(PredMBB, &MBB, MBBInfo->FBB);
298
299  MachineBasicBlock *RootMBB = BranchPath.back();
300  TargetMBBInfo *RootMBBInfo = getMBBInfo(RootMBB);
301  assert(RootMBBInfo && "Expecting a candidate MBB");
302  if (RootMBBInfo->Modified)
303    fixupModifiedCond(RootMBB);
304  CC = RootMBBInfo->BranchCode;
305
306  if (CC != X86::COND_E) {
307    MachineBasicBlock::iterator UncondBrI = findUncondBrI(RootMBB);
308    // RootMBB: Cond jump to the original not-taken MBB.
309    X86::CondCode NewCC;
310    switch (CC) {
311    case X86::COND_L:
312      NewCC = X86::COND_G;
313      break;
314    case X86::COND_G:
315      NewCC = X86::COND_L;
316      break;
317    default:
318      llvm_unreachable("unexpected condtional code.");
319    }
320    BuildMI(*RootMBB, UncondBrI, RootMBB->findDebugLoc(UncondBrI),
321            TII->get(X86::JCC_1))
322        .addMBB(RootMBBInfo->FBB).addImm(NewCC);
323
324    // RootMBB: Jump to TargetMBB
325    BuildMI(*RootMBB, UncondBrI, RootMBB->findDebugLoc(UncondBrI),
326            TII->get(X86::JMP_1))
327        .addMBB(TargetMBB);
328    RootMBB->addSuccessor(TargetMBB);
329    fixPHIsInSucc(TargetMBB, &MBB, RootMBB);
330    RootMBB->erase(UncondBrI);
331  } else {
332    replaceBrDest(RootMBB, RootMBBInfo->TBB, TargetMBB);
333  }
334
335  // Fix RootMBB's CmpValue to MBB's CmpValue to TargetMBB. Don't set Imm
336  // directly. Move MBB's stmt to here as the opcode might be different.
337  if (RootMBBInfo->CmpValue != MBBInfo->CmpValue) {
338    MachineInstr *NewCmp = MBBInfo->CmpInstr;
339    NewCmp->removeFromParent();
340    RootMBB->insert(RootMBBInfo->CmpInstr, NewCmp);
341    RootMBBInfo->CmpInstr->eraseFromParent();
342  }
343
344  // Fix branch Probabilities.
345  auto fixBranchProb = [&](MachineBasicBlock *NextMBB) {
346    BranchProbability Prob;
347    for (auto &I : BranchPath) {
348      MachineBasicBlock *ThisMBB = I;
349      if (!ThisMBB->hasSuccessorProbabilities() ||
350          !ThisMBB->isSuccessor(NextMBB))
351        break;
352      Prob = MBPI->getEdgeProbability(ThisMBB, NextMBB);
353      if (Prob.isUnknown())
354        break;
355      TargetProb = Prob * TargetProb;
356      Prob = Prob - TargetProb;
357      setBranchProb(ThisMBB, NextMBB, Prob);
358      if (ThisMBB == RootMBB) {
359        setBranchProb(ThisMBB, TargetMBB, TargetProb);
360      }
361      ThisMBB->normalizeSuccProbs();
362      if (ThisMBB == RootMBB)
363        break;
364      NextMBB = ThisMBB;
365    }
366    return true;
367  };
368  if (CC != X86::COND_E && !TargetProb.isUnknown())
369    fixBranchProb(MBBInfo->FBB);
370
371  if (CC != X86::COND_E)
372    RemoveList.push_back(&MBB);
373
374  // Invalidate MBBInfo just in case.
375  MBBInfos[MBB.getNumber()] = nullptr;
376  MBBInfos[RootMBB->getNumber()] = nullptr;
377
378  LLVM_DEBUG(dbgs() << "After optimization:\nRootMBB is: " << *RootMBB << "\n");
379  if (BranchPath.size() > 1)
380    LLVM_DEBUG(dbgs() << "PredMBB is: " << *(BranchPath[0]) << "\n");
381}
382
383// Driver function for optimization: find the valid candidate and apply
384// the transformation.
385bool X86CondBrFolding::optimize() {
386  bool Changed = false;
387  LLVM_DEBUG(dbgs() << "***** X86CondBr Folding on Function: " << MF.getName()
388                    << " *****\n");
389  // Setup data structures.
390  MBBInfos.resize(MF.getNumBlockIDs());
391  for (auto &MBB : MF)
392    MBBInfos[MBB.getNumber()] = analyzeMBB(MBB);
393
394  for (auto &MBB : MF) {
395    TargetMBBInfo *MBBInfo = getMBBInfo(&MBB);
396    if (!MBBInfo || !MBBInfo->CmpBrOnly)
397      continue;
398    if (MBB.pred_size() != 1)
399      continue;
400    LLVM_DEBUG(dbgs() << "Work on MBB." << MBB.getNumber()
401                      << " CmpValue: " << MBBInfo->CmpValue << "\n");
402    SmallVector<MachineBasicBlock *, 4> BranchPath;
403    if (!findPath(&MBB, BranchPath))
404      continue;
405
406#ifndef NDEBUG
407    LLVM_DEBUG(dbgs() << "Found one path (len=" << BranchPath.size() << "):\n");
408    int Index = 1;
409    LLVM_DEBUG(dbgs() << "Target MBB is: " << MBB << "\n");
410    for (auto I = BranchPath.rbegin(); I != BranchPath.rend(); ++I, ++Index) {
411      MachineBasicBlock *PMBB = *I;
412      TargetMBBInfo *PMBBInfo = getMBBInfo(PMBB);
413      LLVM_DEBUG(dbgs() << "Path MBB (" << Index << " of " << BranchPath.size()
414                        << ") is " << *PMBB);
415      LLVM_DEBUG(dbgs() << "CC=" << PMBBInfo->BranchCode
416                        << "  Val=" << PMBBInfo->CmpValue
417                        << "  CmpBrOnly=" << PMBBInfo->CmpBrOnly << "\n\n");
418    }
419#endif
420    optimizeCondBr(MBB, BranchPath);
421    Changed = true;
422  }
423  NumFixedCondBrs += RemoveList.size();
424  for (auto MBBI : RemoveList) {
425    while (!MBBI->succ_empty())
426      MBBI->removeSuccessor(MBBI->succ_end() - 1);
427
428    MBBI->eraseFromParent();
429  }
430
431  return Changed;
432}
433
434// Analyze instructions that generate CondCode and extract information.
435bool X86CondBrFolding::analyzeCompare(const MachineInstr &MI, unsigned &SrcReg,
436                                      int &CmpValue) {
437  unsigned SrcRegIndex = 0;
438  unsigned ValueIndex = 0;
439  switch (MI.getOpcode()) {
440  // TODO: handle test instructions.
441  default:
442    return false;
443  case X86::CMP64ri32:
444  case X86::CMP64ri8:
445  case X86::CMP32ri:
446  case X86::CMP32ri8:
447  case X86::CMP16ri:
448  case X86::CMP16ri8:
449  case X86::CMP8ri:
450    SrcRegIndex = 0;
451    ValueIndex = 1;
452    break;
453  case X86::SUB64ri32:
454  case X86::SUB64ri8:
455  case X86::SUB32ri:
456  case X86::SUB32ri8:
457  case X86::SUB16ri:
458  case X86::SUB16ri8:
459  case X86::SUB8ri:
460    SrcRegIndex = 1;
461    ValueIndex = 2;
462    break;
463  }
464  SrcReg = MI.getOperand(SrcRegIndex).getReg();
465  if (!MI.getOperand(ValueIndex).isImm())
466    return false;
467  CmpValue = MI.getOperand(ValueIndex).getImm();
468  return true;
469}
470
471// Analyze a candidate MBB and set the extract all the information needed.
472// The valid candidate will have two successors.
473// It also should have a sequence of
474//  Branch_instr,
475//  CondBr,
476//  UnCondBr.
477// Return TargetMBBInfo if MBB is a valid candidate and nullptr otherwise.
478std::unique_ptr<TargetMBBInfo>
479X86CondBrFolding::analyzeMBB(MachineBasicBlock &MBB) {
480  MachineBasicBlock *TBB;
481  MachineBasicBlock *FBB;
482  MachineInstr *BrInstr;
483  MachineInstr *CmpInstr;
484  X86::CondCode CC;
485  unsigned SrcReg;
486  int CmpValue;
487  bool Modified;
488  bool CmpBrOnly;
489
490  if (MBB.succ_size() != 2)
491    return nullptr;
492
493  CmpBrOnly = true;
494  FBB = TBB = nullptr;
495  CmpInstr = nullptr;
496  MachineBasicBlock::iterator I = MBB.end();
497  while (I != MBB.begin()) {
498    --I;
499    if (I->isDebugValue())
500      continue;
501    if (I->getOpcode() == X86::JMP_1) {
502      if (FBB)
503        return nullptr;
504      FBB = I->getOperand(0).getMBB();
505      continue;
506    }
507    if (I->isBranch()) {
508      if (TBB)
509        return nullptr;
510      CC = X86::getCondFromBranch(*I);
511      switch (CC) {
512      default:
513        return nullptr;
514      case X86::COND_E:
515      case X86::COND_L:
516      case X86::COND_G:
517      case X86::COND_NE:
518      case X86::COND_LE:
519      case X86::COND_GE:
520        break;
521      }
522      TBB = I->getOperand(0).getMBB();
523      BrInstr = &*I;
524      continue;
525    }
526    if (analyzeCompare(*I, SrcReg, CmpValue)) {
527      if (CmpInstr)
528        return nullptr;
529      CmpInstr = &*I;
530      continue;
531    }
532    CmpBrOnly = false;
533    break;
534  }
535
536  if (!TBB || !FBB || !CmpInstr)
537    return nullptr;
538
539  // Simplify CondCode. Note this is only to simplify the findPath logic
540  // and will not change the instruction here.
541  switch (CC) {
542  case X86::COND_NE:
543    CC = X86::COND_E;
544    std::swap(TBB, FBB);
545    Modified = true;
546    break;
547  case X86::COND_LE:
548    if (CmpValue == INT_MAX)
549      return nullptr;
550    CC = X86::COND_L;
551    CmpValue += 1;
552    Modified = true;
553    break;
554  case X86::COND_GE:
555    if (CmpValue == INT_MIN)
556      return nullptr;
557    CC = X86::COND_G;
558    CmpValue -= 1;
559    Modified = true;
560    break;
561  default:
562    Modified = false;
563    break;
564  }
565  return std::make_unique<TargetMBBInfo>(TargetMBBInfo{
566      TBB, FBB, BrInstr, CmpInstr, CC, SrcReg, CmpValue, Modified, CmpBrOnly});
567}
568
569bool X86CondBrFoldingPass::runOnMachineFunction(MachineFunction &MF) {
570  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
571  if (!ST.threewayBranchProfitable())
572    return false;
573  const X86InstrInfo *TII = ST.getInstrInfo();
574  const MachineBranchProbabilityInfo *MBPI =
575      &getAnalysis<MachineBranchProbabilityInfo>();
576
577  X86CondBrFolding CondBr(TII, MBPI, MF);
578  return CondBr.optimize();
579}
580